From 415e915fb944d4e21d3a7c54ffeae91ac024335d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Wed, 10 Dec 2025 16:19:13 +0100 Subject: [PATCH 01/68] - Add extra typing with nibabelImage, VectorXd, ScalarType, etc. into several files - use thread_executor instead of internal own executor object in segstats and brainvolstats - use DO_NOT_SAVE_FILE instead of empty string for defaults in segstats.py - update formatting - fix sphinx warnings (argparse extension) - Remove extra Default statement from device argument - Add Fornix Measure for brainvolstats.py - Add path_or_none arg_type in arg_types.py --- FastSurferCNN/utils/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/FastSurferCNN/utils/__init__.py b/FastSurferCNN/utils/__init__.py index 005a63e4..24b5163d 100644 --- a/FastSurferCNN/utils/__init__.py +++ b/FastSurferCNN/utils/__init__.py @@ -22,6 +22,7 @@ "load_config", "logging", "lr_scheduler", + "LTADict", "mapper", "Mask2d", "Mask3d", From 6ead7dbaa4e347c53886a0dbd954ad1c030141b8 Mon Sep 17 00:00:00 2001 From: ClePol Date: Tue, 16 Sep 2025 20:34:20 +0200 Subject: [PATCH 02/68] added corpus callosum module --- CorpusCallosum/README.md | 301 ++++ CorpusCallosum/cc_visualization.py | 120 ++ CorpusCallosum/data/constants.py | 22 + CorpusCallosum/data/fsaverage_cc_template.py | 127 ++ CorpusCallosum/data/fsaverage_centroids.json | 217 +++ CorpusCallosum/data/fsaverage_data.json | 62 + .../data/generate_fsaverage_centroids.py | 122 ++ CorpusCallosum/data/read_write.py | 257 ++++ CorpusCallosum/fastsurfer_cc.py | 483 ++++++ .../localization/localization_inference.py | 373 +++++ .../registration/mapping_helpers.py | 285 ++++ .../segmentation/segmentation_inference.py | 417 ++++++ .../segmentation_postprocessing.py | 132 ++ CorpusCallosum/shape/cc_endpoint_heuristic.py | 182 +++ CorpusCallosum/shape/cc_mesh.py | 1317 +++++++++++++++++ CorpusCallosum/shape/cc_metrics.py | 178 +++ CorpusCallosum/shape/cc_postprocessing.py | 300 ++++ CorpusCallosum/shape/cc_subsegment_contour.py | 988 +++++++++++++ CorpusCallosum/shape/cc_thickness.py | 368 +++++ CorpusCallosum/shape/resample_poly.py | 65 + .../transforms/localization_transforms.py | 76 + .../transforms/segmentation_transforms.py | 105 ++ CorpusCallosum/visualization/visualization.py | 276 ++++ requirements.mac.txt | 6 +- 24 files changed, 6778 insertions(+), 1 deletion(-) create mode 100644 CorpusCallosum/README.md create mode 100644 CorpusCallosum/cc_visualization.py create mode 100644 CorpusCallosum/data/constants.py create mode 100644 CorpusCallosum/data/fsaverage_cc_template.py create mode 100644 CorpusCallosum/data/fsaverage_centroids.json create mode 100644 CorpusCallosum/data/fsaverage_data.json create mode 100644 CorpusCallosum/data/generate_fsaverage_centroids.py create mode 100644 CorpusCallosum/data/read_write.py create mode 100644 CorpusCallosum/fastsurfer_cc.py create mode 100644 CorpusCallosum/localization/localization_inference.py create mode 100644 CorpusCallosum/registration/mapping_helpers.py create mode 100644 CorpusCallosum/segmentation/segmentation_inference.py create mode 100644 CorpusCallosum/segmentation/segmentation_postprocessing.py create mode 100644 CorpusCallosum/shape/cc_endpoint_heuristic.py create mode 100644 CorpusCallosum/shape/cc_mesh.py create mode 100644 CorpusCallosum/shape/cc_metrics.py create mode 100644 CorpusCallosum/shape/cc_postprocessing.py create mode 100644 CorpusCallosum/shape/cc_subsegment_contour.py create mode 100644 CorpusCallosum/shape/cc_thickness.py create 
mode 100644 CorpusCallosum/shape/resample_poly.py
 create mode 100644 CorpusCallosum/transforms/localization_transforms.py
 create mode 100644 CorpusCallosum/transforms/segmentation_transforms.py
 create mode 100644 CorpusCallosum/visualization/visualization.py

diff --git a/CorpusCallosum/README.md b/CorpusCallosum/README.md
new file mode 100644
index 00000000..4467ba02
--- /dev/null
+++ b/CorpusCallosum/README.md
@@ -0,0 +1,301 @@
+# Corpus Callosum Pipeline
+
+A deep-learning-based pipeline for automated segmentation and shape analysis of the corpus callosum in brain MRI scans.
+It also segments the fornix, localizes the AC and PC, and standardizes the orientation of the brain.
+
+## Overview
+
+This pipeline combines deep learning models for localization and segmentation to:
+1. Detect AC (Anterior Commissure) and PC (Posterior Commissure) points
+2. Extract and align midplane slices
+3. Segment the corpus callosum
+4. Post-process the corpus callosum segmentation, including thickness analysis and various shape metrics
+5. Generate visualizations and measurements
+
+
+## Directory Structure
+
+- `weights/` - Trained model weights
+- `transforms/` - Image preprocessing transformations
+- `shape/` - Shape analysis and post-processing tools
+- `registration/` - Tools for image registration and alignment
+- `data/` - Template data and IO
+- `localization/` - Inference script for AC/PC localization
+- `segmentation/` - Inference scripts for CC/FN segmentation
+
+## Command Line Interfaces
+
+### Main Pipeline: `fastsurfer_cc.py`
+
+The main pipeline script performs the complete corpus callosum analysis workflow.
+
+#### Basic Usage
+
+```bash
+# Using individual file paths
+python3 fastsurfer_cc.py --in_mri /path/to/input/mri.mgz --aseg /path/to/input/aseg.mgz --output_dir /path/to/output --verbose
+
+# Using FastSurfer/FreeSurfer subject directory structure
+python3 fastsurfer_cc.py --subject_dir /path/to/freesurfer/subject --verbose
+```
+
+#### Required Arguments
+
+Choose one of these input methods:
+
+**Option 1: Individual files**
+- `--in_mri PATH`: Input MRI file path (FreeSurfer-conformed)
+- `--aseg PATH`: Input segmentation file path
+- `--output_dir PATH`: Directory for output files
+
+**Option 2: FastSurfer/FreeSurfer subject directory**
+- `--subject_dir PATH`: Subject directory containing standard FreeSurfer structure
+  - Automatically uses `mri/orig.mgz` and `mri/aparc.DKTatlas+aseg.deep.mgz`
+  - Creates standard output paths in FreeSurfer structure
+
+#### Optional Arguments
+
+**General Options:**
+- `--verbose`: Enable verbose output and debug plots
+- `--debug_output_dir PATH`: Directory for debug outputs
+
+**Shape Analysis Parameters:**
+- `--num_thickness_points INT`: Number of points for thickness estimation (default: 100)
+- `--subdivisions FLOAT [FLOAT ...]`: List of subdivision fractions for CC subsegmentation (default: following the Hofer-Frahm definition)
+- `--subdivision_method {shape,vertical,angular,eigenvector}`: Method for contour subdivision (default: "shape")
+  - `shape`: Intercallosal subdivision perpendicular to the intercallosal line
+  - `vertical`: Orthogonal to the most anterior and posterior points in the AC/PC-standardized CC contour
+  - `angular`: Subdivision based on equally spaced angles (Hampel et al.)
+ - `eigenvector`: Primary direction (same as FreeSurfer's mri_cc) +- `--contour_smoothing FLOAT`: Gaussian sigma for smoothing during contour detection (default: 1.0) +- `--slice_selection {middle,all,INT}`: Which slices to process (default: "middle") + +**Custom Output Paths:** +- `--upright_volume_path PATH`: Path for upright volume output +- `--segmentation_path PATH`: Path for segmentation output +- `--postproc_results_path PATH`: Path for postprocessing results +- `--cc_markers_path PATH`: Path for CC markers output +- `--upright_lta_path PATH`: Path for upright LTA transform +- `--orient_volume_lta_path PATH`: Path for orientation volume LTA transform +- `--orig_space_segmentation_path PATH`: Path for segmentation in original space +- `--debug_image_path PATH`: Path for debug visualization image + +**Template Saving:** +- `--save_template PATH`: Directory path to save contours.txt and thickness_values.txt files + +#### Examples + +```bash +# Basic analysis with FreeSurfer subject directory +python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 --verbose + +# Custom shape analysis parameters +python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \ + --num_thickness_points 150 \ + --subdivisions 0.2 0.4 0.6 0.8 \ + --subdivision_method angular \ + --contour_smoothing 1.5 + +# Process all slices instead of just middle slice +python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \ + --slice_selection all + +# Save template files for visualization +python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \ + --save_template /data/templates/sub001 +``` + +### Visualization: `cc_visualization.py` + +Creates visualizations of corpus callosum from template files generated by the main pipeline. + +#### Basic Usage + +```bash +# Using contours file +python3 cc_visualization.py --contours /path/to/contours.txt \ + --thickness /path/to/thickness_values.txt \ + --measurement_points /path/to/measurement_points.txt \ + --output_dir /path/to/output + +# Using fsaverage template (no contours file) +python3 cc_visualization.py \ + --thickness /path/to/thickness_values.txt \ + --measurement_points /path/to/measurement_points.txt \ + --output_dir /path/to/output +``` + +#### Required Arguments + +- `--thickness PATH`: Path to thickness_values.txt file +- `--measurement_points PATH`: Path to measurement points file containing original vertex indices +- `--output_dir PATH`: Directory for output files + +#### Optional Arguments + +**Input:** +- `--contours PATH`: Path to contours.txt file (if not provided, uses fsaverage template) + +**Mesh Parameters:** +- `--resolution FLOAT`: Resolution in mm for the mesh (default: 1.0) +- `--smooth_iterations INT`: Number of smoothing iterations to apply to the mesh (default: 1) + +**Visualization Options:** +- `--colormap {red_to_blue,blue_to_red,red_to_yellow,yellow_to_red}`: Colormap for thickness visualization (default: "red_to_yellow") +- `--color_range MIN MAX`: Optional fixed range for the colorbar +- `--legend STRING`: Legend for the colorbar (default: "Thickness (mm)") +- `--twoD`: Generate 2D visualization instead of 3D mesh + +#### Colormap Options + +- `red_to_blue`: Red → Orange → Grey → Light Blue → Blue +- `blue_to_red`: Blue → Light Blue → Grey → Orange → Red +- `red_to_yellow`: Red → Yellow → Light Blue → Blue +- `yellow_to_red`: Yellow → Light Blue → Blue → Red + +#### Examples + +```bash +# Basic 3D mesh visualization +python3 cc_visualization.py \ + --thickness /data/templates/sub001/thickness_values.txt \ + 
--measurement_points /data/templates/sub001/measurement_points.txt \ + --output_dir /data/visualizations/sub001 + +# 2D visualization with custom colormap +python3 cc_visualization.py \ + --thickness /data/templates/sub001/thickness_values.txt \ + --measurement_points /data/templates/sub001/measurement_points.txt \ + --output_dir /data/visualizations/sub001 \ + --twoD \ + --colormap blue_to_red +``` + +## Analysis and Visualization Workflow + +The pipeline supports different analysis modes that determine the type of template data generated and corresponding visualization options: + +### 3D Analysis and Visualization + +When running the main pipeline with `--slice_selection all` and `--save_template`, a complete 3D template is generated: + +```bash +# Generate 3D template data +python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \ + --slice_selection all \ + --save_template /data/templates/sub001 +``` + +This creates: +- `contours.txt`: Multi-slice contour data for 3D reconstruction +- `thickness_values.txt`: Thickness measurements across all slices +- `measurement_points.txt`: 3D vertex indices for thickness measurements + +The 3D template can then be visualized using the standard 3D mesh options: + +```bash +# Create 3D mesh visualization +python3 cc_visualization.py \ + --contours /data/templates/sub001/contours.txt \ + --thickness /data/templates/sub001/thickness_values.txt \ + --measurement_points /data/templates/sub001/measurement_points.txt \ + --output_dir /data/visualizations/sub001 +``` + +**3D Analysis Benefits:** +- Generates complete surface meshes (VTK, FreeSurfer formats) +- Enables volumetric thickness analysis +- Supports advanced 3D visualizations with proper surface topology +- Creates FreeSurfer-compatible overlay files for integration with other tools + +### 2D Analysis and Visualization + +When using `--slice_selection middle` (default) or a specific slice number with `--save_template`: + +```bash +# Generate 2D template data (middle slice) +python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \ + --slice_selection middle \ + --save_template /data/templates/sub001 + +# Or specific slice +python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \ + --slice_selection 5 \ + --save_template /data/templates/sub001 +``` + +This creates template data for a single slice, which should be visualized in 2D mode: + +```bash +# Create 2D visualization +python3 cc_visualization.py \ + --thickness /data/templates/sub001/thickness_values.txt \ + --measurement_points /data/templates/sub001/measurement_points.txt \ + --output_dir /data/visualizations/sub001 \ + --twoD +``` + +**2D Analysis Benefits:** +- Faster processing for single-slice analysis +- 2D visualization is most suitable for displaying downstream statistics + +### Surface Generation Requirements + +**Important:** Complete surface files (VTK, FreeSurfer surface formats, overlay files) are only generated when using `--slice_selection all`. Single-slice analysis cannot produce proper 3D surface topology and will not generate these files. 
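+
+For example, the FreeSurfer surface and overlay listed below can be inspected in FreeSurfer's `freeview` (illustrative call, assuming a standard FreeSurfer installation and the output paths from the examples above):
+
+```bash
+freeview -f /data/visualizations/sub001/cc_mesh.fssurf:overlay=/data/visualizations/sub001/cc_mesh_overlay.curv
+```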
+ +**3D Surface Outputs (only with `--slice_selection all`):** +- `cc_mesh.vtk`: Complete 3D surface mesh +- `cc_mesh.fssurf`: FreeSurfer surface format +- `cc_mesh_overlay.curv`: Thickness overlay for FreeSurfer visualization + +**2D Outputs (any slice selection):** +- `cc_mesh_snap.png`: 2D visualization or 3D mesh snapshot +- Standard analysis JSON files with measurements + +### Choosing Analysis Mode + +**Use 3D Analysis (`--slice_selection all`) when:** +- You need complete volumetric analysis +- Surface-based visualization is required +- Integration with FreeSurfer workflows is needed +- Comprehensive thickness mapping across the entire corpus callosum is desired + +**Use 2D Analysis (`--slice_selection middle` or specific slice) when:** +- Traditional single-slice morphometry is sufficient +- Faster processing is preferred +- Focus is on mid-sagittal cross-sectional measurements +- Compatibility with classical corpus callosum studies is needed + +## Outputs + +The pipeline produces the following outputs in the specified output directory: + +### Main Pipeline Outputs + +- `cc_markers.json`: Contains detected landmarks and measurements +- `cc_postproc_results.json`: Enhanced postprocessing results with per-slice analysis +- `orient_volume.lta`: Transformation matrix for orientation standardization (AC at origin, PC on anterior-posterior axis) +- `upright.lta`: Transformation matrix for midplane alignment (midsagittal plane cuts brain into hemispheres) +- `upright_volume.mgz`: Original volume mapped with upright.lta +- `segmentation.mgz`: Corpus callosum segmentation on midsagittal plane in upright_volume.mgz space +- `segmentation_orig_space.mgz`: Corpus callosum segmentation in original image orientation +- `cc_postprocessing.png`: Visualization of corpus callosum segmentation and thickness analysis + +### Template Files (when --save_template is used) + +- `contours.txt`: Corpus callosum contour coordinates +- `thickness_values.txt`: Thickness measurements at each point +- `measurement_points.txt`: Original vertex indices where thickness was measured + +### Visualization Outputs + +**3D Mode Outputs (default, when `--twoD` is not specified):** +- `cc_mesh.vtk`: VTK format mesh file for 3D visualization +- `cc_mesh.fssurf`: FreeSurfer surface format for integration with FreeSurfer tools +- `cc_mesh_overlay.curv`: FreeSurfer overlay file containing thickness values +- `cc_mesh.html`: Interactive 3D mesh visualization in HTML format +- `cc_mesh_snap.png`: Snapshot image of the 3D mesh +- `midslice_2d.png`: 2D visualization of the middle slice contour with thickness + +**2D Mode Outputs (when `--twoD` is specified):** +- `cc_thickness_2d.png`: 2D contour visualization with thickness colormap \ No newline at end of file diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py new file mode 100644 index 00000000..fbae0946 --- /dev/null +++ b/CorpusCallosum/cc_visualization.py @@ -0,0 +1,120 @@ +import argparse +from pathlib import Path +import numpy as np +from CorpusCallosum.data.fsaverage_cc_template import load_fsaverage_cc_template + + +from CorpusCallosum.shape.cc_mesh import CC_Mesh + +def options_parse() -> argparse.Namespace: + """Parse command line arguments for the visualization pipeline. 
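+
+    Returns:
+        argparse.Namespace: Parsed command-line arguments; the output directory
+        is created if it does not exist.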
+ """ + parser = argparse.ArgumentParser(description="Visualize corpus callosum from template files.") + parser.add_argument("--contours", type=str, required=False, help="Path to contours.txt file", default=None) + parser.add_argument("--thickness", type=str, required=True, help="Path to thickness_values.txt file") + parser.add_argument("--measurement_points", type=str, required=True, + help="Path to measurement points file containing the original vertex indices where thickness was measured") + parser.add_argument("--output_dir", type=str, required=True, help="Directory for output files") + parser.add_argument("--resolution", type=float, default=1.0, help="Resolution in mm for the mesh") + parser.add_argument("--smooth_iterations", type=int, default=1, help="Number of smoothing iterations to apply to the mesh") + parser.add_argument("--colormap", type=str, default="red_to_yellow", + choices=["red_to_blue", "blue_to_red", "red_to_yellow", "yellow_to_red"], + help="Colormap to use for thickness visualization") + parser.add_argument("--color_range", type=float, nargs=2, default=None, + metavar=('MIN', 'MAX'), + help="Optional fixed range for the colorbar (min max)") + parser.add_argument("--legend", type=str, default="Thickness (mm)", help="Legend for the colorbar") + parser.add_argument("--twoD", action="store_true", help="Generate 2D visualization instead of 3D mesh") + + args = parser.parse_args() + + # Create output directory if it doesn't exist + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + return args + + +def main(contours_path: str | Path | None, thickness_path: str | Path, measurement_points_path: str | Path, + output_dir: str | Path, resolution: float = 1.0, smooth_iterations: int = 1, + colormap: str = "red_to_yellow", color_range: tuple[float, float] | None = None, + legend: str | None = None, twoD: bool = False) -> None: + """Main function to visualize corpus callosum from template files. + + This function: + 1. Loads contours and thickness values from template files + 2. Creates a CC_Mesh object + 3. Generates and saves visualizations + + Args: + contours_path: Path to contours.txt file + thickness_path: Path to thickness_values.txt file + measurement_points_path: Path to file containing the original vertex indices where thickness was measured + output_dir: Directory for output files + resolution: Resolution in mm for the mesh + smooth_iterations: Number of smoothing iterations to apply to the mesh + colormap: Which colormap to use. 
Options are: + - "red_to_blue": Red -> Orange -> Grey -> Light Blue -> Blue + - "blue_to_red": Blue -> Light Blue -> Grey -> Orange -> Red + - "red_to_yellow": Red -> Yellow -> Light Blue -> Blue + - "yellow_to_red": Yellow -> Light Blue -> Blue -> Red + color_range: Optional tuple of (min, max) to set fixed color range for the colorbar + twoD: If True, generate 2D visualization instead of 3D mesh + """ + # Convert paths to Path objects + contours_path = Path(contours_path) if contours_path is not None else None + thickness_path = Path(thickness_path) + measurement_points_path = Path(measurement_points_path) + output_dir = Path(output_dir) + + # Load data and create mesh + cc_mesh = CC_Mesh(num_slices=1) # Will be resized when loading data + + if contours_path is not None: + cc_mesh.load_contours(str(contours_path)) + else: + cc_contour, anterior_endpoint_idx, posterior_endpoint_idx = load_fsaverage_cc_template() + cc_mesh.contours[0] = np.stack(cc_contour).T + cc_mesh.start_end_idx[0] = [anterior_endpoint_idx, posterior_endpoint_idx] + + + cc_mesh.load_thickness_values(str(thickness_path), str(measurement_points_path)) + cc_mesh.set_resolution(resolution) + + if twoD: + #cc_mesh.smooth_contour(contour_idx=0, window_size=5) + cc_mesh.plot_cc_contour_with_levelsets(contour_idx=0, levelpaths=None, title=None, save_path=str(output_dir / 'cc_thickness_2d.png'), colorbar=True) + else: + cc_mesh.fill_thickness_values() + # Create and process mesh + cc_mesh.create_mesh(smooth=smooth_iterations, closed=False) + + + + # Generate visualizations + cc_mesh.plot_mesh(colormap=colormap, color_range=color_range, thickness_overlay=True, show_contours=False, show_mesh_edges=True, legend=legend) + cc_mesh.plot_mesh(str(output_dir / 'cc_mesh.html'), thickness_overlay=True) + + cc_mesh.plot_cc_contour_with_levelsets(contour_idx=len(cc_mesh.contours)//2, save_path=str(output_dir / 'midslice_2d.png')) + + cc_mesh.to_fs_coordinates() + cc_mesh.write_vtk(str(output_dir / 'cc_mesh.vtk')) + cc_mesh.write_fssurf(str(output_dir / 'cc_mesh.fssurf')) + cc_mesh.write_overlay(str(output_dir / 'cc_mesh_overlay.curv')) + cc_mesh.snap_cc_picture(str(output_dir / 'cc_mesh_snap.png')) + + +if __name__ == "__main__": + options = options_parse() + main_args = { + 'contours_path': options.contours, + 'thickness_path': options.thickness, + 'measurement_points_path': options.measurement_points, + 'output_dir': options.output_dir, + 'resolution': options.resolution, + 'smooth_iterations': options.smooth_iterations, + 'colormap': options.colormap, + 'color_range': options.color_range, + 'legend': options.legend, + 'twoD': options.twoD + } + main(**main_args) \ No newline at end of file diff --git a/CorpusCallosum/data/constants.py b/CorpusCallosum/data/constants.py new file mode 100644 index 00000000..9c70642e --- /dev/null +++ b/CorpusCallosum/data/constants.py @@ -0,0 +1,22 @@ +from pathlib import Path + +### Constants +WEIGHTS_PATH = Path(__file__).parent.parent / "weights" +FSAVERAGE_CENTROIDS_PATH = Path(__file__).parent / "fsaverage_centroids.json" +FSAVERAGE_DATA_PATH = Path(__file__).parent / "fsaverage_data.json" # Contains both affine and header +FSAVERAGE_MIDDLE = 128 # Middle slice index in fsaverage space +CC_LABEL = 192 # Label value for corpus callosum in segmentation +FORNIX_LABEL = 250 # Label value for fornix in segmentation + + +STANDARD_OUTPUT_PATHS = { + "upright_volume": "mri/upright_volume.mgz", + "segmentation": "mri/cc_segmentation.mgz", + "postproc_results": "stats/cc_postproc_results.json", + "cc_markers": 
"stats/cc_markers.json", + "upright_lta": "transforms/upright.lta", + "orient_volume_lta": "transforms/orient_volume.lta", + "orig_space_segmentation": "mri/segmentation_orig_space.mgz", + "debug_image": "stats/cc_postprocessing.png", + #"qc_view": "stats/qc_view.png" +} \ No newline at end of file diff --git a/CorpusCallosum/data/fsaverage_cc_template.py b/CorpusCallosum/data/fsaverage_cc_template.py new file mode 100644 index 00000000..0f8329db --- /dev/null +++ b/CorpusCallosum/data/fsaverage_cc_template.py @@ -0,0 +1,127 @@ +import nibabel as nib +import matplotlib.pyplot as plt +from shape.cc_postprocessing import process_slice +from pathlib import Path +import os +from scipy import ndimage +import numpy as np + +def smooth_contour(contour, window_size=5): + """ + Smooth a contour using a moving average filter + + Parameters: + ----------- + contour : tuple of arrays + The contour coordinates (x, y) + window_size : int + Size of the smoothing window + + Returns: + -------- + tuple of arrays + The smoothed contour coordinates (x, y) + """ + x, y = contour + + # Ensure the window size is odd + if window_size % 2 == 0: + window_size += 1 + + # Create a padded version of the arrays to handle the edges + x_padded = np.pad(x, (window_size//2, window_size//2), mode='wrap') + y_padded = np.pad(y, (window_size//2, window_size//2), mode='wrap') + + # Apply moving average + x_smoothed = np.zeros_like(x) + y_smoothed = np.zeros_like(y) + + for i in range(len(x)): + x_smoothed[i] = np.mean(x_padded[i:i+window_size]) + y_smoothed[i] = np.mean(y_padded[i:i+window_size]) + + return (x_smoothed, y_smoothed) + +def load_fsaverage_cc_template(): + # smooth outside contour + # Apply smoothing to the outside contour using a moving average + + + freesurfer_home = Path(os.environ['FREESURFER_HOME']) + + if not freesurfer_home.exists(): + raise EnvironmentError(f"FREESURFER_HOME environment variable is not set correctly or does not exist: {freesurfer_home}, either provide your own template or set the FREESURFER_HOME environment variable") + + fsaverage_seg_path = freesurfer_home / 'subjects' / 'fsaverage' / 'mri' / 'aparc+aseg.mgz' + fsaverage_seg = nib.load(fsaverage_seg_path) + segmentation = fsaverage_seg.get_fdata() + + PC = np.array([131, 99]) + AC = np.array([135, 130]) + + + midslice = segmentation.shape[0]//2 +1 + + cc_mask = segmentation[midslice] == 251 + cc_mask |= segmentation[midslice] == 252 + cc_mask |= segmentation[midslice] == 253 + cc_mask |= segmentation[midslice] == 254 + cc_mask |= segmentation[midslice] == 255 + + # Smooth the CC mask to reduce noise and irregularities + + # Apply binary closing to fill small holes + cc_mask_smoothed = ndimage.binary_closing(cc_mask, structure=np.ones((3, 3))) + + # Apply binary opening to remove small isolated pixels + cc_mask_smoothed = ndimage.binary_opening(cc_mask_smoothed, structure=np.ones((2, 2))) + + # Apply Gaussian smoothing and threshold to get a binary mask again + cc_mask_smoothed = ndimage.gaussian_filter(cc_mask_smoothed.astype(float), sigma=0.8) + cc_mask_smoothed = cc_mask_smoothed > 0.5 + + # Use the smoothed mask for further processing + cc_mask = cc_mask_smoothed.astype(int) + cc_mask[cc_mask > 0] = 192 + + output_dict, contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx = process_slice(cc_mask[None], 0, AC, PC, fsaverage_seg.affine, 100, [1/6, 1/2, 2/3, 3/4], "shape", 1.0, verbose=False) + outside_contour = contour_with_thickness[0].T + + + outside_contour[0][anterior_endpoint_idx] -= 55 + 
outside_contour[0][posterior_endpoint_idx] += 30 + + # Apply smoothing to the outside contour + outside_contour_smoothed = smooth_contour(outside_contour, window_size=11) + outside_contour_smoothed = smooth_contour(outside_contour_smoothed, window_size=15) + outside_contour_smoothed = smooth_contour(outside_contour_smoothed, window_size=30) + outside_contour = outside_contour_smoothed + + + # Plot CC contour with levelsets + + # midline_equidistant = output_dict['midline_equidistant'] + # levelpaths = output_dict['levelpaths'] + # plt.figure(figsize=(12, 8)) + + # plt.plot(outside_contour[0], outside_contour[1], 'k-', linewidth=2) + + # # Plot the midline + # if midline_equidistant is not None: + # midline_x, midline_y = zip(*midline_equidistant) + # plt.plot(midline_x, midline_y, 'r-', linewidth=2, label='Midline') + + # # Plot the level paths + # if levelpaths: + # for i, path in enumerate(levelpaths): + # path_x, path_y = path[:,0], path[:,1] + # plt.plot(path_x, path_y, 'g--', linewidth=1, alpha=0.7, label=f'Level path {i+1}' if i == 0 else "") + # plt.plot(path_x, path_y, 'gx', markersize=4, alpha=0.7) + + # plt.axis('equal') + # plt.title('Corpus Callosum Contour with Levelsets') + # plt.legend(loc='best') + # plt.grid(True, linestyle='--', alpha=0.7) + # plt.show() + + return outside_contour, anterior_endpoint_idx, posterior_endpoint_idx diff --git a/CorpusCallosum/data/fsaverage_centroids.json b/CorpusCallosum/data/fsaverage_centroids.json new file mode 100644 index 00000000..bccf1189 --- /dev/null +++ b/CorpusCallosum/data/fsaverage_centroids.json @@ -0,0 +1,217 @@ +{ + "2": [ + -27.242888317659038, + -22.210776052870685, + 18.546657917012894 + ], + "3": [ + -32.18990180647074, + -16.863336561239265, + 16.015058654310195 + ], + "4": [ + -14.455663189269757, + -13.693461251862885, + 13.7136736214605 + ], + "5": [ + -33.906934306569354, + -22.284671532846716, + -15.821167883211672 + ], + "7": [ + -17.305372931308085, + -53.43157258369229, + -36.01715408448575 + ], + "8": [ + -22.265822784810126, + -64.36629649763144, + -37.674831094198964 + ], + "10": [ + -11.752497096399537, + -19.87584204413473, + 5.165737514518 + ], + "11": [ + -15.034188034188048, + 9.437551695616207, + 6.913427074717404 + ], + "12": [ + -26.366197183098592, + -0.15686274509803866, + -2.091549295774655 + ], + "13": [ + -20.91671388101983, + -5.188668555240795, + -2.4107648725212414 + ], + "14": [ + 0.5832045337454872, + -11.11695002575992, + -3.9433281813498127 + ], + "15": [ + 0.5413500223513665, + -46.56236030397854, + -33.21814930710772 + ], + "16": [ + 0.8273686582297444, + -31.946261594502232, + -31.003755304367417 + ], + "17": [ + -26.088480154888686, + -24.429622458857693, + -15.148886737657307 + ], + "18": [ + -23.90932509015971, + -7.339515713549716, + -20.63575476558475 + ], + "24": [ + 0.6026785714285694, + -20.70535714285714, + 8.040736607142861 + ], + "26": [ + -9.629820051413873, + 10.960154241645256, + -8.786632390745496 + ], + "28": [ + -11.456631660832358, + -16.84694671334111, + -10.32691559704395 + ], + "30": [ + -28.545454545454533, + -3.200000000000003, + -10.181818181818187 + ], + "31": [ + -12.502610966057432, + -12.218015665796344, + 6.30548302872063 + ], + "41": [ + 27.68021284305685, + -21.297671313867227, + 18.84475807220643 + ], + "42": [ + 32.70257488842361, + -15.910019860438453, + 16.482307738602415 + ], + "43": [ + 15.18157827962446, + -13.241715300685101, + 14.257802588175593 + ], + "44": [ + 33.10191082802548, + -17.921443736730367, + -16.980891719745216 + ], + "46": [ + 
19.070892410341955, + -53.51368564713019, + -35.67336416710896 + ], + "47": [ + 23.65288732176549, + -64.41682904951904, + -37.19518418854969 + ], + "49": [ + 12.493538246594483, + -19.225986727209218, + 5.663872394923743 + ], + "50": [ + 16.15939771547248, + 9.458463136033231, + 8.239096573208727 + ], + "51": [ + 26.94455762514552, + 0.5477299185099014, + -2.249126891734562 + ], + "52": [ + 22.105321507760536, + -4.939024390243901, + -1.9539911308204125 + ], + "53": [ + 27.74364210135512, + -23.379431965843693, + -14.994987933914985 + ], + "54": [ + 24.942549371633746, + -6.010771992818675, + -20.737881508079 + ], + "58": [ + 9.986789960369876, + 10.424042272126826, + -8.705416116248358 + ], + "60": [ + 12.434200157604408, + -16.41252955082743, + -10.056737588652481 + ], + "62": [ + 30.558139534883722, + -2.581395348837205, + -10.441860465116292 + ], + "63": [ + 12.008567931456554, + -11.022031823745408, + 7.3671970624235 + ], + "77": [ + -13.714285714285722, + -15.714285714285708, + 0.9285714285714306 + ], + "85": [ + 1.466019417475735, + -0.2038834951456323, + -18.466019417475735 + ], + "251": [ + 0.5403535741737073, + -35.800153727901616, + 16.784780937740194 + ], + "252": [ + 0.6063829787234027, + -18.29361702127659, + 24.748936170212772 + ], + "253": [ + 0.5847299813780324, + -2.424581005586589, + 25.815642458100555 + ], + "254": [ + 0.7008849557522154, + 11.998230088495575, + 20.40530973451328 + ], + "255": [ + 0.8761467889908232, + 24.612844036697254, + 5.411009174311928 + ] +} \ No newline at end of file diff --git a/CorpusCallosum/data/fsaverage_data.json b/CorpusCallosum/data/fsaverage_data.json new file mode 100644 index 00000000..9efa2336 --- /dev/null +++ b/CorpusCallosum/data/fsaverage_data.json @@ -0,0 +1,62 @@ +{ + "affine": [ + [ + -1.0, + 0.0, + 0.0, + 128.0 + ], + [ + 0.0, + 0.0, + 1.0, + -128.0 + ], + [ + 0.0, + -1.0, + 0.0, + 128.0 + ], + [ + 0.0, + 0.0, + 0.0, + 1.0 + ] + ], + "header": { + "dims": [ + 256, + 256, + 256 + ], + "delta": [ + 1.0, + 1.0, + 1.0 + ], + "Mdc": [ + [ + -1.0, + 0.0, + 0.0 + ], + [ + 0.0, + 0.0, + 10000000000.0 + ], + [ + 0.0, + -10000000000.0, + 0.0 + ] + ], + "Pxyz_c": [ + 128.0, + -128.0, + 128.0 + ] + } +} \ No newline at end of file diff --git a/CorpusCallosum/data/generate_fsaverage_centroids.py b/CorpusCallosum/data/generate_fsaverage_centroids.py new file mode 100644 index 00000000..46ac01bd --- /dev/null +++ b/CorpusCallosum/data/generate_fsaverage_centroids.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 +""" +Script to generate static fsaverage centroids file. + +This script extracts centroids from the fsaverage template segmentation +and saves them to a JSON file for fast loading during pipeline execution. +Run this script once to generate the centroids file. 
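+
+Usage (assumes the FREESURFER_HOME environment variable points to a valid
+FreeSurfer installation):
+    python3 generate_fsaverage_centroids.py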
+""" + +import os +import json +from pathlib import Path +import numpy as np +import nibabel as nib +from read_write import get_centroids_from_nib, convert_numpy_to_json_serializable + + +def main(): + """Generate and save fsaverage centroids to a static file.""" + + # Get fsaverage path from FreeSurfer environment + try: + fs_home = Path(os.environ['FREESURFER_HOME']) + if not fs_home.exists(): + raise EnvironmentError(f"FREESURFER_HOME environment variable is not set correctly or does not exist: {fs_home}") + + fsaverage_path = fs_home / 'subjects' / 'fsaverage' + if not fsaverage_path.exists(): + raise EnvironmentError(f"fsaverage path does not exist: {fsaverage_path}") + + fsaverage_aseg_path = fsaverage_path / 'mri' / 'aseg.mgz' + if not fsaverage_aseg_path.exists(): + raise FileNotFoundError(f"fsaverage aseg file does not exist: {fsaverage_aseg_path}") + + except KeyError: + raise EnvironmentError("FREESURFER_HOME environment variable is not set") + + print(f"Loading fsaverage segmentation from: {fsaverage_aseg_path}") + + # Load fsaverage segmentation + fsaverage_nib = nib.load(fsaverage_aseg_path) + + # Extract centroids + print("Extracting centroids from fsaverage...") + centroids_dst = get_centroids_from_nib(fsaverage_nib) + + print(f"Found {len(centroids_dst)} anatomical structures with centroids") + + # Convert to JSON-serializable format + centroids_serializable = convert_numpy_to_json_serializable(centroids_dst) + + # Save centroids to JSON file + centroids_output_path = Path(__file__).parent / "fsaverage_centroids.json" + with open(centroids_output_path, 'w') as f: + json.dump(centroids_serializable, f, indent=2) + + print(f"Fsaverage centroids saved to: {centroids_output_path}") + print(f"Centroids file size: {centroids_output_path.stat().st_size} bytes") + + # Extract and save fsaverage affine matrix and header fields + print("Extracting fsaverage affine matrix and header fields...") + fsaverage_affine = fsaverage_nib.affine.astype(float) # Convert to float for JSON serialization + + # Extract header fields needed for LTA + header = fsaverage_nib.header + dims = [int(x) for x in header.get_data_shape()[:3]] # Convert to int for JSON serialization + delta = [float(x) for x in header.get_zooms()[:3]] # Convert to float for JSON serialization + vox2ras = header.get_vox2ras() + + # Direction cosines matrix (Mdc) - extract rotation part without scaling + delta_diag = np.diag(delta) + # Avoid division by zero by using a small epsilon for zero values + delta_safe = np.where(delta_diag == 0, 1e-10, delta_diag) + Mdc = (vox2ras[:3, :3] / delta_safe).astype(float) # Convert to float for JSON serialization + + Pxyz_c = vox2ras[:3, 3].astype(float) # Convert to float for JSON serialization + + # Combine affine and header data + combined_data = { + "affine": fsaverage_affine.tolist(), # Convert numpy array to list for JSON serialization + "header": { + "dims": dims, + "delta": delta, + "Mdc": Mdc.tolist(), # Convert numpy array to list for JSON serialization + "Pxyz_c": Pxyz_c.tolist() # Convert numpy array to list for JSON serialization + } + } + + # Convert the entire structure to JSON-serializable format to handle any remaining numpy types + combined_data_serializable = convert_numpy_to_json_serializable(combined_data) + + # Save combined data to JSON file + combined_output_path = Path(__file__).parent / "fsaverage_data.json" + with open(combined_output_path, 'w') as f: + json.dump(combined_data_serializable, f, indent=2) + + print(f"Fsaverage affine and header data saved to: 
{combined_output_path}") + print(f"Combined file size: {combined_output_path.stat().st_size} bytes") + print(f"Affine matrix shape: {fsaverage_affine.shape}") + print(f"Header dims: {dims}, delta: {delta}") + + # Print some statistics + label_ids = list(centroids_dst.keys()) + print(f"Label IDs range: {min(label_ids)} to {max(label_ids)}") + print(f"Sample centroids:") + for label_id in sorted(label_ids)[:5]: + centroid = centroids_dst[label_id] + print(f" Label {label_id}: [{centroid[0]:.2f}, {centroid[1]:.2f}, {centroid[2]:.2f}]") + + print(f"Fsaverage affine matrix:") + print(fsaverage_affine) + + print(f"Fsaverage header fields:") + print(f" dims: {dims}") + print(f" delta: {delta}") + print(f" Mdc shape: {Mdc.shape}") + print(f" Pxyz_c: {Pxyz_c}") + print(f"Combined data structure created successfully") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/CorpusCallosum/data/read_write.py b/CorpusCallosum/data/read_write.py new file mode 100644 index 00000000..3db13150 --- /dev/null +++ b/CorpusCallosum/data/read_write.py @@ -0,0 +1,257 @@ +import multiprocessing +import numpy as np +import nibabel as nib + + +def run_in_background(function, debug=False, *args, **kwargs): + """Run a function in the background using multiprocessing. + + This function executes the given function either in a separate process (normal mode) + or in the current process (debug mode). In debug mode, the function is executed + synchronously for easier debugging. + + Args: + function: The function to execute + debug (bool): If True, run synchronously in current process + *args: Positional arguments to pass to the function + **kwargs: Keyword arguments to pass to the function + + Returns: + multiprocessing.Process or None: Process object if running in background, + None if in debug mode + """ + if debug: + function(*args, **kwargs) + process = None + else: + process = multiprocessing.Process(target=function, args=args, kwargs=kwargs) + process.start() + return process + + + +def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: list[int] | None = None) -> dict[int, np.ndarray]: + """Get centroids of segmentation labels in RAS coordinates. + + Calculates the centroid coordinates for each segmentation label in the image. + If label_ids is provided, only calculates centroids for those specific labels. + Coordinates are returned in RAS (Right-Anterior-Superior) coordinate system. + + Args: + seg_img (nib.Nifti1Image): Nibabel image containing segmentation labels + label_ids (list[int] | None): Optional list of specific label IDs to process. + If None, processes all non-zero labels. 
+ + Returns: + If label_ids is None: + dict[int, np.ndarray]: Mapping of label IDs to their centroids (x,y,z) in RAS coordinates + If label_ids is provided: + tuple: Contains: + - dict[int, np.ndarray]: Mapping of found label IDs to their centroids + - list[int]: List of label IDs that were not found in the image + """ + # Get segmentation data and affine + seg_data = seg_img.get_fdata() + vox2ras = seg_img.affine + + # Get unique labels + if label_ids is None: + labels = np.unique(seg_data) + labels = labels[labels > 0] # Exclude background + else: + labels = label_ids + + centroids = {} + ids_not_found = [] + for label in labels: + # Get voxel indices for this label + vox_coords = np.array(np.where(seg_data == label)) + if vox_coords.size == 0: + ids_not_found.append(label) + continue + # Calculate centroid in voxel space + vox_centroid = np.mean(vox_coords, axis=1) + + # Convert to homogeneous coordinates + vox_centroid = np.append(vox_centroid, 1) + + # Transform to RAS coordinates + ras_centroid = vox2ras @ vox_centroid + + # Store without homogeneous coordinate + centroids[int(label)] = ras_centroid[:3] + + if label_ids is not None: + return centroids, ids_not_found + else: + return centroids + + + +def save_nifti_background(io_processes, data, affine, header, filepath): + """Save a NIfTI image in a background process. + + Creates a MGHImage from the provided data and metadata, then saves it to disk + using a background process to avoid blocking the main execution. + + Args: + io_processes (list): List to store background process handles + data (np.ndarray): Image data array + affine (np.ndarray): 4x4 affine transformation matrix + header: NIfTI header object containing metadata + filepath (str): Path where the image should be saved + """ + io_processes.append(run_in_background(nib.save, False, + nib.MGHImage(data, affine, header), filepath)) + + +def convert_numpy_to_json_serializable(obj): + """Convert numpy arrays in nested data structures to JSON serializable format. + + Recursively traverses dictionaries, lists, and numpy arrays, converting numpy arrays + to Python lists and numpy scalars to Python scalars for JSON serialization. + + Args: + obj: Any Python object that may contain numpy arrays (dict, list, np.ndarray, or scalar) + + Returns: + The input object with all numpy arrays converted to lists and numpy scalars to Python scalars + + Example: + >>> data = {'array': np.array([1, 2, 3]), 'nested': {'array': np.array([4, 5])}} + >>> result = convert_numpy_to_json_serializable(data) + >>> # Result: {'array': [1, 2, 3], 'nested': {'array': [4, 5]}} + """ + if isinstance(obj, dict): + return {k: convert_numpy_to_json_serializable(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [convert_numpy_to_json_serializable(item) for item in obj] + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, (np.integer, np.floating)): + # Handle numpy scalar types + return obj.item() + else: + return obj + + +def load_fsaverage_centroids(centroids_path): + """Load fsaverage centroids from static JSON file. + + Loads pre-computed centroids from a static JSON file, avoiding the need to + compute them from the fsaverage segmentation at runtime. 
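+
+    The static file is generated by generate_fsaverage_centroids.py in this
+    directory.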
+ + Args: + centroids_path (str or Path): Path to the JSON file containing centroids + + Returns: + dict[int, np.ndarray]: Mapping of label IDs to their centroids (x,y,z) in RAS coordinates + + Raises: + FileNotFoundError: If the centroids file doesn't exist + json.JSONDecodeError: If the file is not valid JSON + """ + import json + from pathlib import Path + + centroids_path = Path(centroids_path) + if not centroids_path.exists(): + raise FileNotFoundError(f"Fsaverage centroids file not found: {centroids_path}") + + with open(centroids_path, 'r') as f: + centroids_data = json.load(f) + + # Convert string keys back to integers and lists back to numpy arrays + centroids = {} + for label_str, centroid_list in centroids_data.items(): + label_id = int(label_str) + centroids[label_id] = np.array(centroid_list) + + return centroids + + +def load_fsaverage_affine(affine_path): + """Load fsaverage affine matrix from static text file. + + Loads pre-computed affine matrix from a static text file, avoiding the need to + load the fsaverage segmentation at runtime. + + Args: + affine_path (str or Path): Path to the text file containing affine matrix + + Returns: + np.ndarray: 4x4 affine transformation matrix + + Raises: + FileNotFoundError: If the affine file doesn't exist + ValueError: If the file doesn't contain a valid 4x4 matrix + """ + from pathlib import Path + + affine_path = Path(affine_path) + if not affine_path.exists(): + raise FileNotFoundError(f"Fsaverage affine file not found: {affine_path}") + + affine_matrix = np.loadtxt(affine_path) + + if affine_matrix.shape != (4, 4): + raise ValueError(f"Expected 4x4 affine matrix, got shape {affine_matrix.shape}") + + return affine_matrix + + +def load_fsaverage_data(data_path): + """Load fsaverage affine matrix and header fields from static JSON file. + + Loads pre-computed affine matrix and header fields from a static JSON file, + avoiding the need to load the fsaverage segmentation at runtime. 
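+
+    Both the affine and the header fields are written by
+    generate_fsaverage_centroids.py alongside the centroids file.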
+ + Args: + data_path (str or Path): Path to the JSON file containing combined data + + Returns: + tuple: Contains: + - affine_matrix (np.ndarray): 4x4 affine transformation matrix + - header_fields (dict): Header fields needed for LTA: + - dims (list[int]): Volume dimensions [x,y,z] + - delta (list[float]): Voxel size in mm [x,y,z] + - Mdc (np.ndarray): 3x3 direction cosines matrix + - Pxyz_c (np.ndarray): RAS center coordinates [x,y,z] + + Raises: + FileNotFoundError: If the data file doesn't exist + json.JSONDecodeError: If the file is not valid JSON + ValueError: If required fields are missing + """ + import json + from pathlib import Path + + data_path = Path(data_path) + if not data_path.exists(): + raise FileNotFoundError(f"Fsaverage data file not found: {data_path}") + + with open(data_path, 'r') as f: + data = json.load(f) + + # Verify required fields + if "affine" not in data: + raise ValueError("Required field 'affine' missing from data file") + if "header" not in data: + raise ValueError("Required field 'header' missing from data file") + + header_fields = ["dims", "delta", "Mdc", "Pxyz_c"] + for field in header_fields: + if field not in data["header"]: + raise ValueError(f"Required header field missing: {field}") + + # Convert lists back to numpy arrays + affine_matrix = np.array(data["affine"]) + header_data = data["header"].copy() + header_data["Mdc"] = np.array(header_data["Mdc"]) + header_data["Pxyz_c"] = np.array(header_data["Pxyz_c"]) + + # Validate affine matrix shape + if affine_matrix.shape != (4, 4): + raise ValueError(f"Expected 4x4 affine matrix, got shape {affine_matrix.shape}") + + return affine_matrix, header_data \ No newline at end of file diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py new file mode 100644 index 00000000..7f2f9f6a --- /dev/null +++ b/CorpusCallosum/fastsurfer_cc.py @@ -0,0 +1,483 @@ +import argparse +import json +import warnings +warnings.filterwarnings("ignore", message="TypedStorage is deprecated") + +from pathlib import Path + +import nibabel as nib +import numpy as np +import torch + +from localization import localization_inference +from segmentation import segmentation_inference, segmentation_postprocessing +from recon_surf import lta +from CorpusCallosum.registration.mapping_helpers import interpolate_midplane, get_mapping_to_standard_space, map_softlabels_to_orig, apply_transform_to_pt, apply_transform_and_map_volume +from CorpusCallosum.shape.cc_postprocessing import process_slices, create_visualization + +from FastSurferCNN.data_loader.conform import is_conform +from recon_surf.align_points import find_rigid +from CorpusCallosum.data.read_write import save_nifti_background, get_centroids_from_nib, convert_numpy_to_json_serializable, run_in_background, load_fsaverage_centroids, load_fsaverage_data + +from CorpusCallosum.data.constants import * + + + + +def options_parse() -> argparse.Namespace: + """Parse command line arguments for the pipeline. + """ + parser = argparse.ArgumentParser() + parser.add_argument("--in_mri", type=str, required=False, help="Input MRI file path. If not provided, defaults to subject_dir/mri/orig.mgz") + parser.add_argument("--aseg", type=str, required=False, help="Input segmentation file path. If not provided, defaults to subject_dir/mri/aparc.DKTatlas+aseg.deep.mgz") + parser.add_argument("--subject_dir", type=str, required=False, help="Subject directory containing standard FreeSurfer structure. 
Required if --in_mri and --aseg are not both provided.", default=None)
+    parser.add_argument("--debug_output_dir", type=str, required=False, default=None, help="Directory for debug outputs")
+    parser.add_argument("--verbose", action="store_true", help="Enable verbose output and debug plots")
+
+    # CC shape arguments
+    parser.add_argument("--num_thickness_points", type=int, default=100, help="Number of points for thickness estimation.")
+    parser.add_argument("--subdivisions", type=float, nargs='+', default=[1/6, 1/2, 2/3, 3/4], help="List of subdivision fractions for the corpus callosum subsegmentation.")
+    parser.add_argument("--subdivision_method", type=str, default="shape", help="Method for contour subdivision. \
+                        Options: shape (intercallosal subdivision perpendicular to the intercallosal line), vertical \
+                        (orthogonal to the most anterior and posterior points in the AC/PC standardized CC contour), \
+                        angular (subdivision based on equally spaced angles, as proposed by Hampel and colleagues), \
+                        eigenvector (primary direction, same as FreeSurfer's mri_cc)", choices=["shape", "vertical", "angular", "eigenvector"])
+    parser.add_argument("--contour_smoothing", type=float, default=1.0, help="Gaussian sigma for smoothing during contour detection. Default is 1.0; higher values give a smoother outline at the cost of precision.")
+    parser.add_argument("--slice_selection", type=str, default="middle", help="Which slices to process. Options: 'middle' (default), 'all', or a specific slice number.")
+
+    # Output path arguments
+    parser.add_argument("--upright_volume_path", type=str, help="Path for upright volume output (default: subject_dir/mri/upright_volume.mgz)", default=None)
+    parser.add_argument("--segmentation_path", type=str, help="Path for segmentation output (default: subject_dir/mri/cc_segmentation.mgz)", default=None)
+    parser.add_argument("--postproc_results_path", type=str, help="Path for postprocessing results (default: subject_dir/stats/cc_postproc_results.json)", default=None)
+    parser.add_argument("--cc_markers_path", type=str, help="Path for CC markers output (default: subject_dir/stats/cc_markers.json)", default=None)
+    parser.add_argument("--upright_lta_path", type=str, help="Path for upright LTA transform (default: subject_dir/transforms/upright.lta)", default=None)
+    parser.add_argument("--orient_volume_lta_path", type=str, help="Path for orientation volume LTA transform (default: subject_dir/transforms/orient_volume.lta)", default=None)
+    parser.add_argument("--orig_space_segmentation_path", type=str, help="Path for segmentation in original space (default: subject_dir/mri/segmentation_orig_space.mgz)", default=None)
+    parser.add_argument("--debug_image_path", type=str, help="Path for debug visualization image (default: subject_dir/stats/cc_postprocessing.png)", default=None)
+
+    # Template saving argument
+    parser.add_argument("--save_template", type=str, help="Directory path to save the contours.txt and thickness_values.txt files", default=None)
+
+    args = parser.parse_args()
+
+    # Validation logic: either subject_dir OR both in_mri and aseg must be provided
+    if not args.subject_dir and (not args.in_mri or not args.aseg):
+        parser.error("You must specify either --subject_dir OR both --in_mri and --aseg arguments.")
+
+    # If subject_dir is provided, set default paths for missing arguments
+    if args.subject_dir:
+        subject_dir_path = Path(args.subject_dir)
+
+        # Create standard FreeSurfer subdirectories
+        (subject_dir_path / "mri").mkdir(parents=True, exist_ok=True)
+        (subject_dir_path / 
"stats").mkdir(parents=True, exist_ok=True) + (subject_dir_path / "transforms").mkdir(parents=True, exist_ok=True) + + if not args.in_mri: + args.in_mri = str(subject_dir_path / "mri" / "orig.mgz") + + if not args.aseg: + args.aseg = str(subject_dir_path / "mri" / "aparc.DKTatlas+aseg.deep.mgz") + + # Set default output paths if not provided + for key, value in STANDARD_OUTPUT_PATHS.items(): + if not getattr(args, f"{key}_path"): + setattr(args, f"{key}_path", str(subject_dir_path / value)) + + # Set output_dir to subject_dir + args.output_dir = str(subject_dir_path) + + + # Create parent directories for all output paths + for path in [args.upright_volume_path, args.segmentation_path, args.postproc_results_path, args.cc_markers_path, args.upright_lta_path, args.orient_volume_lta_path]: + if path is not None: + Path(path).parent.mkdir(parents=True, exist_ok=True) + + return args + + +def centroid_registration(aseg_nib, verbose=False): + """Perform centroid-based registration between subject and fsaverage space. + + Computes a rigid transformation between the subject's segmentation and fsaverage space + by aligning centroids of corresponding anatomical structures. + + Args: + aseg_nib (nib.Nifti1Image): Subject's segmentation image + verbose (bool): Whether to print progress information + + Returns: + tuple: Contains: + - orig_fsaverage_vox2vox: Transformation matrix from original to fsaverage voxel space + - orig_fsaverage_ras2ras: Transformation matrix from original to fsaverage RAS space + - fsaverage_hires_affine: High-resolution fsaverage affine matrix + - fsaverage_header: FSAverage header fields for LTA writing + """ + if verbose: + print("Centroid registration") + + # Load pre-computed fsaverage centroids and data from static files + centroids_dst = load_fsaverage_centroids(FSAVERAGE_CENTROIDS_PATH) + fsaverage_affine, fsaverage_header = load_fsaverage_data(FSAVERAGE_DATA_PATH) + + centroids_mov, ids_not_found = get_centroids_from_nib(aseg_nib, label_ids=list(centroids_dst.keys())) + + # delete not found labels from centroids_mov + for id in ids_not_found: + del centroids_dst[id] + + centroids_mov = np.array(list(centroids_mov.values())).T + centroids_dst = np.array(list(centroids_dst.values())).T + + orig_fsaverage_ras2ras = find_rigid(p_mov=centroids_mov.T, p_dst=centroids_dst.T) + + # make affine that increases resolution to orig resolution + resolution_orig = aseg_nib.header.get_zooms()[0] + resolution_trans = np.eye(4) + resolution_trans[0, 0] = resolution_orig + resolution_trans[1, 1] = resolution_orig + resolution_trans[2, 2] = resolution_orig + + orig_fsaverage_vox2vox = np.linalg.inv(resolution_trans @ fsaverage_affine) @ orig_fsaverage_ras2ras @ aseg_nib.affine + fsaverage_hires_affine = resolution_trans @ fsaverage_affine + + return orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, fsaverage_hires_affine, fsaverage_header + + +def localize_ac_pc(midslices, aseg_nib, orig_fsaverage_vox2vox, model_localization, slices_to_analyze, verbose=False): + """Localize anterior and posterior commissure points in the brain. + + Uses a trained model to detect AC and PC points in mid-sagittal slices, + using the third ventricle as an anatomical reference. 
+
+    Args:
+        midslices (np.ndarray): Array of mid-sagittal slices
+        aseg_nib (nib.Nifti1Image): Subject's segmentation image
+        orig_fsaverage_vox2vox (np.ndarray): Transformation matrix to fsaverage space
+        model_localization: Trained model for AC-PC detection
+        slices_to_analyze (int): Number of slices to process
+        verbose (bool): Whether to print progress information
+
+    Returns:
+        tuple: Contains:
+            - ac_coords (np.ndarray): Coordinates of the anterior commissure
+            - pc_coords (np.ndarray): Coordinates of the posterior commissure
+    """
+    if verbose:
+        print("Localization and segmentation inference")
+
+    # get center of third ventricle from aseg and map to fsaverage space
+    third_ventricle_mask = aseg_nib.get_fdata() == 4
+    third_ventricle_center = np.argwhere(third_ventricle_mask).mean(axis=0)
+    third_ventricle_center_vox = apply_transform_to_pt(third_ventricle_center, orig_fsaverage_vox2vox, inv=False)
+
+    # select the 5 mm of central slices plus one extra slice on each side (the model uses a 3-slice context per inference)
+    midslices_middle = midslices.shape[0] // 2
+    middle_slices_localization = midslices[midslices_middle-slices_to_analyze//2-1:midslices_middle+slices_to_analyze//2+2]
+    ac_coords, pc_coords = localization_inference.run_inference_on_slice(model_localization, middle_slices_localization, third_ventricle_center_vox[1:])
+
+    return ac_coords, pc_coords
+
+
+def segment_cc(midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, slices_to_analyze):
+    """Segment the corpus callosum using a trained model.
+
+    Performs corpus callosum segmentation on mid-sagittal slices using a trained model,
+    with AC-PC points as anatomical references. Includes post-processing to clean the segmentation.
+
+    Args:
+        midslices (np.ndarray): Array of mid-sagittal slices
+        ac_coords (np.ndarray): Anterior commissure coordinates
+        pc_coords (np.ndarray): Posterior commissure coordinates
+        aseg_nib (nib.Nifti1Image): Subject's segmentation image
+        model_segmentation: Trained model for CC segmentation
+        slices_to_analyze (int): Number of slices to process
+
+    Returns:
+        tuple: Contains:
+            - segmentation (np.ndarray): Binary segmentation of the corpus callosum
+            - outputs_soft (np.ndarray): Soft segmentation probabilities
+    """
+    # select the 5 mm of central slices plus four extra slices on each side (the model uses a 9-slice context per inference)
+    midslices_middle = midslices.shape[0] // 2
+    middle_slices_segmentation = midslices[midslices_middle-slices_to_analyze//2-4:midslices_middle+slices_to_analyze//2+5]
+    segmentation, inputs, outputs_avg, outputs_soft = segmentation_inference.run_inference_on_slice(model_segmentation,
+                                                                                                    middle_slices_segmentation,
+                                                                                                    AC_center=ac_coords, PC_center=pc_coords,
+                                                                                                    voxel_size=aseg_nib.header.get_zooms()[0])
+
+    pre_clean_segmentation = segmentation.copy()
+    segmentation, cc_volume_mask = segmentation_postprocessing.clean_cc_segmentation(segmentation)
+
+    # print a warning if the cc_volume_mask touches the edge of the segmentation
+    if np.any(cc_volume_mask[:, 0, :]) or np.any(cc_volume_mask[:, -1, :]) or np.any(cc_volume_mask[:, :, 0]) or np.any(cc_volume_mask[:, :, -1]):
+        print("Warning: CC volume mask touches the edge of the segmentation field-of-view, CC might be truncated")
+
+    # get voxels that were removed during cleaning
+    removed_voxels = pre_clean_segmentation != segmentation
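+    # zero the matching soft-label probabilities so the soft output stays
+    # consistent with the cleaned hard segmentation (class index 1 appears to
+    # hold the CC probability channel)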
+    outputs_soft[removed_voxels, 1] = 0
+
+    return segmentation, outputs_soft
+
+
+def main(in_mri_path: str | Path, aseg_path: str | Path, output_dir: str | Path, slice_selection: str = "middle",
+         debug_output_dir: str | Path | None = None, verbose: bool = False, num_thickness_points: int = 100,
+         subdivisions: list[float] | None = None, subdivision_method: str = "shape",
+         contour_smoothing: float = 1.0,
+         upright_volume_path: str | Path | None = None, segmentation_path: str | Path | None = None,
+         postproc_results_path: str | Path | None = None, cc_markers_path: str | Path | None = None,
+         upright_lta_path: str | Path | None = None, orient_volume_lta_path: str | Path | None = None,
+         orig_space_segmentation_path: str | Path | None = None, debug_image_path: str | Path | None = None,
+         save_template: str | Path | None = None) -> None:
+    """Main pipeline function for corpus callosum analysis.
+
+    This function performs the following steps:
+    1. Initializes environment and loads models
+    2. Registers input image to fsaverage space
+    3. Detects AC and PC points
+    4. Segments the corpus callosum
+    5. Performs enhanced post-processing analysis
+    6. Saves results and visualizations
+
+    Args:
+        in_mri_path: Path to input MRI file
+        aseg_path: Path to input segmentation file
+        output_dir: Directory for output files
+        slice_selection: Which slices to process ('middle', 'all', or specific slice number)
+        debug_output_dir: Optional directory for debug outputs
+        verbose: Flag for verbose output
+        num_thickness_points: Number of points for thickness estimation
+        subdivisions: List of subdivision fractions for CC subsegmentation
+        subdivision_method: Method for contour subdivision
+        contour_smoothing: Gaussian sigma for smoothing during contour detection
+        upright_volume_path: Path for upright volume output (default: output_dir/upright_volume.mgz)
+        segmentation_path: Path for segmentation output (default: output_dir/segmentation.mgz)
+        postproc_results_path: Path for postprocessing results (default: output_dir/cc_postproc_results.json)
+        cc_markers_path: Path for CC markers output (default: output_dir/cc_markers.json)
+        upright_lta_path: Path for upright LTA transform (default: output_dir/upright.lta)
+        orient_volume_lta_path: Path for orientation volume LTA transform (default: output_dir/orient_volume.lta)
+        orig_space_segmentation_path: Path for segmentation in original space (default: output_dir/mri/segmentation_orig_space.mgz)
+        debug_image_path: Path for debug visualization image (default: output_dir/stats/cc_postprocessing.png)
+        save_template: Directory path where to save contours.txt and thickness_values.txt files
+
+    The function saves multiple outputs to the specified paths or to default locations in output_dir:
+    - cc_markers.json: Contains detected landmarks and measurements
+    - midplane_slices.mgz: Extracted midplane slices
+    - upright_volume.mgz: Volume aligned to standard orientation
+    - segmentation.mgz: Corpus callosum segmentation
+    - cc_postproc_results.json: Enhanced postprocessing results
+    - Various visualization plots and transformation matrices
+    """
+    if subdivisions is None:
+        subdivisions = [1/6, 1/2, 2/3, 3/4]
+
+    # Convert all paths to Path objects
+    in_mri_path = Path(in_mri_path)
+    aseg_path = Path(aseg_path)
+    output_dir = Path(output_dir)
+    debug_output_dir = Path(debug_output_dir) if debug_output_dir else None
+    save_template = Path(save_template) if save_template else None
+
+    # Validate subdivision fractions
+    for i in subdivisions:
+        if i < 0 or i > 1:
+            print(f"Error: Subdivision fractions must be between 0 and 1, but got: {i}")
+            exit(1)
+
+    #### setup variables
+    IO_processes = []
+
+    orig = nib.load(in_mri_path)
+
+    # analyze a 5 mm slab around the midplane (odd number of slices)
+    slices_to_analyze = int(np.ceil(5 / orig.header.get_zooms()[0]))
+    if slices_to_analyze % 2 == 0:
+        slices_to_analyze += 1
+
+    if verbose:
+        print(f"Segmenting {slices_to_analyze} slices (5 mm width at {orig.header.get_zooms()[0]} mm resolution, centered on the mid-sagittal plane)")
+
+    if not is_conform(orig, conform_vox_size=orig.header.get_zooms()[0]):
+        print("Error: MRI is not conformed, please run conform.py or mri_convert to conform the image.")
+        exit(1)
+
+    # load models
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model_localization = localization_inference.load_model(str(Path(WEIGHTS_PATH) / "localization_weights_acpc.pth"), device=device)
+    model_segmentation = segmentation_inference.load_model(str(Path(WEIGHTS_PATH) / "segmentation_weights_cc_fn.pth"), device=device)
+
+    aseg_nib = nib.load(aseg_path)
+
+    orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, fsaverage_hires_affine, fsaverage_header = centroid_registration(aseg_nib, verbose)
+
+    if verbose:
+        print("Interpolating midplane")
+
+    # this is a fast interpolation so it does not block the main thread
+    midslices = interpolate_midplane(orig, orig_fsaverage_vox2vox, slices_to_analyze)
+
+    # start saving the upright volume in the background
+    IO_processes.append(run_in_background(apply_transform_and_map_volume, False,
+                                          orig.get_fdata(), orig_fsaverage_vox2vox, fsaverage_hires_affine, None,
+                                          upright_volume_path, output_size=np.array([256, 256, 256])))
+
+    #### do localization and segmentation inference
+    ac_coords, pc_coords = localize_ac_pc(midslices, aseg_nib, orig_fsaverage_vox2vox, model_localization, slices_to_analyze, verbose)
+    segmentation, outputs_soft = segment_cc(midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, slices_to_analyze)
+
+    # map soft labels to original space (in parallel because this takes a while)
+    IO_processes.append(run_in_background(map_softlabels_to_orig, False,
+                                          outputs_soft, orig_fsaverage_vox2vox, orig, slices_to_analyze,
+                                          orig_space_segmentation_path, fsaverage_middle=FSAVERAGE_MIDDLE))
+
+    # Affine for the temporary segmentation image used in enhanced postprocessing
+    temp_seg_affine = fsaverage_hires_affine
+
+    # Process slices based on selection mode
+    slice_results, slice_io_processes = process_slices(
+        segmentation=segmentation,
+        slice_selection=slice_selection,
+        temp_seg_affine=temp_seg_affine,
+        midslices=midslices,
+        ac_coords=ac_coords,
+        pc_coords=pc_coords,
+        num_thickness_points=num_thickness_points,
+        subdivisions=subdivisions,
+        subdivision_method=subdivision_method,
+        contour_smoothing=contour_smoothing,
+        output_dir=output_dir,
+        debug_image_path=debug_image_path,
+        vox_size=orig.header.get_zooms()[0],
+        verbose=verbose,
+        save_template=save_template
+    )
+    IO_processes.extend(slice_io_processes)
+
+    # Get middle slice result for backward compatibility
+    middle_slice_result = slice_results[len(slice_results)//2]
+
+    # Create enhanced output dictionary with all slice results
+    per_slice_output_dict = {
+        'slices': [convert_numpy_to_json_serializable({
+            'slice_index': result['slice_index'],
+            'cc_index': result['cc_index'],
+            'circularity': result['circularity'],
+            'areas': result['areas'],
+            'midline_length': result['midline_length'],
+            'thickness': result['thickness'],
+            'curvature': result['curvature'],
+            'thickness_profile': result['thickness_profile'],
+            'total_area': result['total_area'],
+            'total_perimeter': result['total_perimeter']
+        }) for result in slice_results],
+        'slices_in_segmentation': segmentation.shape[0],
+        'voxel_size': [float(x) for x in orig.header.get_zooms()],
+        'subdivision_method': subdivision_method,
+        'num_thickness_points': num_thickness_points,
+        'subdivisions': subdivisions,
+        'contour_smoothing': contour_smoothing,
+        'slice_selection': slice_selection
+    }
+
+    # Save slice-wise postprocessing results to JSON
+    with open(postproc_results_path, "w") as f:
+        json.dump(per_slice_output_dict, f, indent=4)
+
+    if verbose:
+        print(f"Multiple slice post-processing results saved to {postproc_results_path}")
+
+    ########## Save outputs ##########
+
+    cc_volume = segmentation_postprocessing.get_cc_volume(desired_width_mm=5, cc_mask=segmentation == CC_LABEL,
+                                                          voxel_size=orig.header.get_zooms())
+
+    # Create backward compatible output_dict for the existing pipeline using the middle slice
+    output_dict = {
+        'areas': middle_slice_result['areas'],
+        # NOTE: both branches below are identical as written
+        'areas_hofer_frahm': middle_slice_result['areas'] if middle_slice_result['split_contours_hofer_frahm'] is not None else middle_slice_result['areas'],
+        'thickness': middle_slice_result['thickness'],
+        'curvature': middle_slice_result['curvature'],
+        'midline_length': middle_slice_result['midline_length'],
+        'circularity': middle_slice_result['circularity'],
+        'cc_index': middle_slice_result['cc_index'],
+        'total_area': middle_slice_result['total_area'],
+        'total_perimeter': middle_slice_result['total_perimeter'],
+        'thickness_profile': middle_slice_result['thickness_profile']
+    }
+
+    # multiply split contours with the resolution scale factor for the middle slice visualization
+    split_contours = [split_contour * orig.header.get_zooms()[1] for split_contour in middle_slice_result['split_contours']]
+    if middle_slice_result['split_contours_hofer_frahm'] is not None:
+        split_contours_hofer_frahm = [split_contour * orig.header.get_zooms()[1] for split_contour in middle_slice_result['split_contours_hofer_frahm']]
+    else:
+        split_contours_hofer_frahm = split_contours  # backward compatibility
+    midline_equidistant = middle_slice_result['midline_equidistant'] * orig.header.get_zooms()[1]
+    levelpaths = [levelpath * orig.header.get_zooms()[1] for levelpath in middle_slice_result['levelpaths']]
+
+    # Save middle slice visualization
+    single_slice_result = {
+        'split_contours': split_contours,
+        'split_contours_hofer_frahm': split_contours_hofer_frahm,
+        'midline_equidistant': midline_equidistant,
+        'levelpaths': levelpaths
+    }
+    IO_processes.append(create_visualization(subdivision_method, single_slice_result, midslices,
+                                             output_dir, ac_coords, pc_coords, orig.header.get_zooms()[0], ' (Middle Slice)'))
+
+    # get AC and PC in all spaces
+    ac_coords_3d = np.hstack((FSAVERAGE_MIDDLE, ac_coords))
+    pc_coords_3d = np.hstack((FSAVERAGE_MIDDLE, pc_coords))
+    standardized_to_orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig = \
+        get_mapping_to_standard_space(orig, ac_coords_3d, pc_coords_3d, orig_fsaverage_vox2vox, output_dir)
+
+    # save segmentation with a fitting affine
+    orig_to_seg = np.eye(4)
+    orig_to_seg[0, 3] = -FSAVERAGE_MIDDLE + slices_to_analyze//2
+    seg_affine = fsaverage_hires_affine @ np.linalg.inv(orig_to_seg)
+    save_nifti_background(IO_processes, segmentation, seg_affine, orig.header, segmentation_path)
+
+    # add AC/PC landmark coordinates to the output dict (written as JSON below)
+    output_dict["ac_center"] = ac_coords_orig
+    output_dict["pc_center"] = pc_coords_orig
+    output_dict["ac_center_oriented_volume"] = ac_coords_standardized
+    output_dict["pc_center_oriented_volume"] = pc_coords_standardized
+    output_dict["ac_center_upright"] = ac_coords_3d
+    output_dict["pc_center_upright"] = pc_coords_3d
+    output_dict["cc_5mm_volume"] = cc_volume
+    output_dict["num_slices"] = slices_to_analyze
+
+    # Convert numpy arrays to lists for JSON serialization
+    output_dict = convert_numpy_to_json_serializable(output_dict)
+
+    with open(cc_markers_path, "w") as f:
+        json.dump(output_dict, f, indent=4)
+
+    # save LTA to fsaverage space
+    lta.writeLTA(upright_lta_path, orig_fsaverage_ras2ras, aseg_path, aseg_nib.header, 'fsaverage', fsaverage_header)
+
+    # save LTA to standardized space (fsaverage + nodding correction + AC to center)
+    orig_to_standardized_ras2ras = orig.affine @ np.linalg.inv(standardized_to_orig_vox2vox) @ np.linalg.inv(orig.affine)
+    lta.writeLTA(orient_volume_lta_path, orig_to_standardized_ras2ras, in_mri_path, orig.header, in_mri_path, orig.header)
+
+    for process in IO_processes:
+        if process is not None:
+            process.join()
+
+
+if __name__ == "__main__":
+    options = options_parse()
+    main_args = vars(options)
+
+    # Rename keys to match main function parameters
+    main_args['in_mri_path'] = main_args.pop('in_mri')
+    main_args['aseg_path'] = main_args.pop('aseg')
+    main_args['output_dir'] = main_args.pop('subject_dir', '.')
+
+    main(**main_args)
diff --git a/CorpusCallosum/localization/localization_inference.py b/CorpusCallosum/localization/localization_inference.py
new file mode 100644
index 00000000..9119c678
--- /dev/null
+++ b/CorpusCallosum/localization/localization_inference.py
@@ -0,0 +1,373 @@
+import time
+import torch
+import numpy as np
+import nibabel as nib
+from monai import transforms
+from monai.networks.nets import DenseNet as DenseNet_monai
+
+from transforms.localization_transforms import CropAroundACPCFixedSize
+
+
+def load_model(checkpoint_path, device=None):
+    """
+    Load the trained numerical localization model from checkpoint
+
+    Args:
+        checkpoint_path: Path to a model checkpoint, or an already loaded state dict
+        device: torch device to load model to (defaults to CUDA if available)
+
+    Returns:
+        model: Loaded model
+    """
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Initialize model architecture (must match training)
+    model = DenseNet_monai(  # densenet201
+        spatial_dims=2,
+        in_channels=3,
+        out_channels=4,
+        init_features=64,
+        growth_rate=32,
+        block_config=(6, 12, 48, 32),
+        bn_size=4,
+        act=("relu", {"inplace": True}),
+        norm=("batch", {"affine": True}),
+        dropout_prob=0.2
+    )
+
+    # Load state dict
+    if isinstance(checkpoint_path, str):
+        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
+        if isinstance(state_dict, dict) and 'model_state_dict' in state_dict:
+            state_dict = state_dict['model_state_dict']
+    else:
+        state_dict = checkpoint_path
+
+    model.load_state_dict(state_dict)
+    model = model.to(device)
+    model.eval()
+    return model
+
+def get_transforms():
+    """Get preprocessing transforms for inference"""
+    tr = [
+        transforms.ScaleIntensityd(keys=['image'], minv=0, maxv=1),
+        CropAroundACPCFixedSize(
+            keys=['image'],
+            fixed_size=(64, 64),
+            random_translate=0
+        ),
+    ]
+    return transforms.Compose(tr)
+
+def preprocess_volume(image_volume, center_pt, transform=None):
+    """
+    Preprocess a volume for inference
+
+    Args:
+        image_volume: Input volume as numpy array
+        center_pt: Initial crop center (used for both the AC and PC center keys)
+        transform: Optional custom transform pipeline
+
+    Returns:
+        preprocessed: Preprocessed image tensor ready for model input
+    """
+    if transform is None:
+        transform = get_transforms()
+
+    sample = {"image": image_volume, "AC_center": center_pt, "PC_center": center_pt}
+
+    # Apply transforms
+    transformed = transform(sample)
+
+    # Add batch dimension if needed
+    if torch.is_tensor(transformed["image"]):
+        if len(transformed["image"].shape) == 3:
+            transformed["image"] = transformed["image"].unsqueeze(0)
+
+    return transformed
+
+def run_inference(model, image_volume, third_ventricle_center, device=None, transform=None):
+    """
+    Run inference on an image volume
+
+    Args:
+        model: Trained model
+        image_volume: Input volume as numpy array
+        third_ventricle_center: 2D in-plane crop center (e.g. the third ventricle centroid)
+        device: torch device to run inference on
+        transform: Optional custom transform pipeline
+
+    Returns:
+        tuple: two arrays with the per-view predicted commissure coordinates in
+        original image space, the network inputs, and the (crop_left, crop_top) offsets
+    """
+    if device is None:
+        device = next(model.parameters()).device
+
+    # prepend zero to third_ventricle_center
+    third_ventricle_center = np.concatenate([np.zeros(1), third_ventricle_center])
+
+    # Preprocess
+    t_dict = preprocess_volume(image_volume[None], third_ventricle_center, transform)
+
+    transformed_original = t_dict['image']
+    inputs = transformed_original.to(device)
+
+    inputs = inputs.transpose(0, 1)
+    batch_size, channels, height, width = inputs.shape
+    views = []
+    for i in range(batch_size - 2):  # -2 to ensure we have 3 slices per view
+        view = inputs[i:i+3]  # Take 3 consecutive slices
+        view = view.reshape(1, 3*channels, height, width)  # Reshape to combine slices into channels
+        views.append(view)
+
+    inputs = torch.cat(views, dim=0)  # Stack all views into batch dimension
+
+    # Run inference
+    with torch.no_grad():
+        outputs = model(inputs)
+
+    # scale normalized network outputs to the 64x64 crop size
+    outputs = outputs * 64
+
+    # undo the crop offsets to get coordinates in original image space
+    outputs[:, 0] += t_dict['crop_left']
+    outputs[:, 1] += t_dict['crop_top']
+    outputs[:, 2] += t_dict['crop_left']
+    outputs[:, 3] += t_dict['crop_top']
+
+    return outputs[:, :2].cpu().numpy(), outputs[:, 2:].cpu().numpy(), inputs.cpu().numpy(), (t_dict['crop_left'], t_dict['crop_top'])
+
+def load_validation_data(path):
+    import pandas as pd
+    data = pd.read_csv(path, index_col=0, header=None)
+    data.columns = ["image", "label", "AC_center_x", "AC_center_y", "AC_center_z",
+                    "PC_center_x", "PC_center_y", "PC_center_z"]
+
+    data = data.drop(['15656', '5bd8d9b2-e0d3-4a40-b00c-03dfffc5b206'], errors='ignore')
+
+    ac_centers = data[["AC_center_x", "AC_center_y", "AC_center_z"]].values
+    pc_centers = data[["PC_center_x", "PC_center_y", "PC_center_z"]].values
+    images = data["image"].values
+
+    label_widths = []
+    for label_path in data['label']:
+        label_img = nib.load(label_path)
+
+        if label_img.shape[0] > 100:
+            # check which slices have non-zero values
+            label = label_img.get_fdata()
+            non_zero_slices = np.any(label > 0, axis=(1,2))
+            first_nonzero = np.argmax(non_zero_slices)
+            last_nonzero = len(non_zero_slices) - np.argmax(non_zero_slices[::-1])
+            label_widths.append(last_nonzero - first_nonzero)
+        else:
+            label_widths.append(label_img.shape[0])
+
+    extended_data = pd.read_csv("/groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/network/data/found_labels_with_meta_data_difficult_final.csv", index_col=0)
+    extended_data = extended_data.loc[data.index]
+
+    third_ventricle_centers = []
+    vox_sizes = []
+    for aseg_up in extended_data['aseg_up_nocc']:
+        aseg_up_img = nib.load(aseg_up)
+        aseg_up_data = aseg_up_img.get_fdata()
+
+        tv_center = np.mean(np.argwhere(aseg_up_data == 14), axis=0)[1:]
+
+        if np.isnan(tv_center).any():
+            raise ValueError(f"No third ventricle (label 14) found in {aseg_up}")
+
+        third_ventricle_centers.append(tv_center)
+        vox_sizes.append(aseg_up_img.header.get_zooms()[1])
+
+    subj_ids = data.index.values
+
+    return images, ac_centers, pc_centers, label_widths, third_ventricle_centers, vox_sizes, subj_ids
+
+
+def run_inference_on_slice(model, image_slice, center_pt, debug_output=None):
+    # Two-pass inference: first centered on the provided point, then re-centered
+    # on the mean of the first-pass predictions.
+    # NOTE: the two landmark arrays are unpacked here as (pc, ac), whereas
+    # run_validation below unpacks them as (ac, pc).
+    pc_coords, ac_coords, inputs, (crop_left, crop_top) = run_inference(model, image_slice, center_pt)
+    center_pt = np.mean(np.concatenate([ac_coords, pc_coords], axis=0), axis=0)
+    pc_coords, ac_coords, inputs, (crop_left, crop_top) = run_inference(model, image_slice, center_pt)
+    pc_coords = np.mean(pc_coords, axis=0)
+    ac_coords = np.mean(ac_coords, axis=0)
+
+    if debug_output is not None:
+        import matplotlib.pyplot as plt
+        from matplotlib.patches import Rectangle
+        fig, ax = plt.subplots(1, 1, figsize=(10, 8))
+        ax.imshow(image_slice[image_slice.shape[0]//2, :, :], cmap='gray')
+        # Plot points on all views
+        ax.scatter(pc_coords[1], pc_coords[0], c='r', marker='x', label='PC')
+        ax.scatter(ac_coords[1], ac_coords[0], c='b', marker='x', label='AC')
+        # make a box where the crop is
+        ax.add_patch(Rectangle((crop_top, crop_left), 64, 64, fill=False, color='r', linewidth=2))
+        plt.savefig(debug_output, bbox_inches='tight')
+        plt.close()
+
+    return ac_coords, pc_coords
+
+
+# TODO: add check if the prediction of the first and second round diverges too much
+
+def run_validation():
+    # Load model
+    model_path = '/workspace/weights/localization_weights_acpc.pth'
+    model = load_model(model_path)
+
+    val_images, val_ac, val_pc, val_label_widths, val_third_ventricle_centers, val_vox_sizes, val_subj_ids = \
+        load_validation_data("/groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/network/data/test_joined_labels.csv")
+
+    dist_out = []
+    dist_out_dict = {}
+    uncertainty_out_dict = {}
+    for img_path, AC_center, PC_center, label_width, third_ventricle_center, vox_size, subj_id in zip(
+            val_images, val_ac, val_pc, val_label_widths, val_third_ventricle_centers, val_vox_sizes, val_subj_ids):
+
+        test_img = nib.load(img_path)
+        test_slice = test_img.get_fdata()
+
+        # crop to the middle label_width slices (plus one on each side)
+        test_slice = test_slice[256//2-label_width//2-1:256//2+label_width//2+2]
+
+        # Run inference
+        start_time = time.time()
+        ac_coords, pc_coords, inputs, (crop_left, crop_top) = run_inference(model, test_slice, third_ventricle_center)
+        center_pt = np.mean(np.concatenate([ac_coords, pc_coords], axis=0), axis=0)
+        ac_coords, pc_coords, inputs, (crop_left, crop_top) = run_inference(model, test_slice, center_pt)
+
+        inference_time = time.time() - start_time
+        print(f"Inference took {inference_time:.3f} seconds")
+
+        ac_dist = np.linalg.norm(AC_center[1:] - np.mean(ac_coords, axis=0)) / vox_size
+        pc_dist = np.linalg.norm(PC_center[1:] - np.mean(pc_coords, axis=0)) / vox_size
+        dist_out.append([ac_dist, pc_dist])
+        dist_out_dict[subj_id] = [ac_dist, pc_dist]
+
+        print(f"Distance AC: {ac_dist:.4f}, PC: {pc_dist:.4f}")
+
+    import pandas as pd
+    dist_out_df = pd.DataFrame.from_dict(dist_out_dict, orient='index', columns=['ac_dist', 'pc_dist'])
+    dist_out_df.to_csv("/workspace/outputs/dist_out_dict.csv")
+
+    uncertainty_out_df = pd.DataFrame.from_dict(uncertainty_out_dict, orient='index', columns=['pc_uncertainty', 'ac_uncertainty'])
+    uncertainty_out_df.to_csv("/workspace/outputs/uncertainty_localization_out_dict.csv")
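+    # NOTE: uncertainty_out_dict is not populated in the loop above, so this
+    # CSV will be empty as written.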
+
+    print(f'Overall error - AC: {np.mean(dist_out, axis=0)[0]:.4f} mm, PC: {np.mean(dist_out, axis=0)[1]:.4f} mm')
+
+    # validation set, middle 2x  AC: 0.7648 mm, PC: 0.8181 mm
+    # validation set, mean 2x    AC: 0.7638 mm, PC: 0.8404 mm --- chose mean
+
+    # test set (mean 2x)         AC: 0.9004 mm, PC: 0.9482 mm
+
+    # difficult set (mean 2x):   AC: 0.9179 mm, PC: 1.3477 mm
+
+
+# Example usage:
+if __name__ == "__main__":
+    run_validation()
\ No newline at end of file
diff --git a/CorpusCallosum/registration/mapping_helpers.py b/CorpusCallosum/registration/mapping_helpers.py
new file mode 100644
index 00000000..ed209d81
--- /dev/null
+++ b/CorpusCallosum/registration/mapping_helpers.py
@@ -0,0 +1,285 @@
+from pathlib import Path
+import numpy as np
+import nibabel as nib
+from scipy.ndimage import affine_transform, map_coordinates
+
+
+def make_midplane_affine(orig_affine, slices_to_analyze=1, offset=4):
+    """
+    Creates an affine transformation matrix for midplane slices.
+
+    Args:
+        orig_affine: Original image affine matrix
+        slices_to_analyze: Number of slices to analyze around midplane (default=1)
+        offset: Additional offset in x direction (default=4)
+
+    Returns:
+        seg_affine: Affine matrix for midplane slices
+    """
+    # Create translation matrix to center on midplane
+    orig_to_seg = np.eye(4)
+    orig_to_seg[0, 3] = -256//2 + slices_to_analyze//2 + offset
+
+    # Combine with original affine
+    seg_affine = orig_affine @ np.linalg.inv(orig_to_seg)
+
+    return seg_affine
+
+
+def correct_nodding(ac_pt, pc_pt):
+    """
+    Calculates a rotation matrix to correct for head nodding based on the AC-PC line orientation.
+
+    Args:
+        ac_pt: Coordinates of the anterior commissure point
+        pc_pt: Coordinates of the posterior commissure point
+
+    Returns:
+        rotation_matrix: 3x3 rotation matrix to align the AC-PC line with the posterior direction
+    """
+    ac_pc_vec = pc_pt - ac_pt
+    ac_pc_dist = np.linalg.norm(ac_pc_vec)
+
+    posterior_vector = np.array([0, -ac_pc_dist])
+
+    # get angle between ac_pc_vec and posterior_vector
+    dot_product = np.dot(ac_pc_vec, posterior_vector)
+    norms_product = np.linalg.norm(ac_pc_vec) * np.linalg.norm(posterior_vector)
+    theta = np.arccos(dot_product / norms_product)
+
+    # Determine the sign of the angle using the cross product
+    cross_product = np.cross(ac_pc_vec, posterior_vector)
+    if cross_product < 0:
+        theta = -theta
+
+    # create rotation matrix for theta
+    rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
+                                [np.sin(theta), np.cos(theta), 0],
+                                [0, 0, 1]])
+
+    return rotation_matrix
+
+
+def apply_transform_to_pt(pts, T, inv=False):
+    """
+    Applies a homogeneous 4x4 transformation matrix to a point or set of points.
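+
+    Accepts a single point as a 1D 3-vector or multiple points as a (3, N)
+    array of column vectors.
+
+    Example (illustrative):
+        >>> T = np.eye(4); T[0, 3] = 2.0  # translate +2 along x
+        >>> apply_transform_to_pt(np.array([1.0, 0.0, 0.0]), T)
+        array([3., 0., 0.])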
+
+    Args:
+        pts: Point coordinates to transform
+        T: Transformation matrix
+        inv: If True, applies the inverse of the transformation (default=False)
+
+    Returns:
+        Transformed point coordinates
+    """
+    if inv:
+        T = np.linalg.inv(T)
+
+    if pts.ndim == 1:
+        return (T @ np.hstack((pts, 1)))[:3]
+    else:
+        return (T @ np.concatenate([pts, np.ones((1, pts.shape[1]))]))[:3]
+
+def get_mapping_to_standard_space(orig, ac_coords_3d, pc_coords_3d, orig_fsaverage_vox2vox, output_dir):
+    """
+    Maps an image to standard space using AC-PC alignment.
+
+    Args:
+        orig: Original image
+        ac_coords_3d: 3D coordinates of anterior commissure
+        pc_coords_3d: 3D coordinates of posterior commissure
+        orig_fsaverage_vox2vox: Original to fsaverage space transformation matrix
+        output_dir: Directory for output files
+
+    Returns:
+        tuple: (transformation matrix, AC coords standardized, PC coords standardized,
+               AC coords original, PC coords original)
+    """
+    image_center = np.array(orig.shape) / 2
+
+    # correct nodding
+    nod_correct_2d = correct_nodding(ac_coords_3d[1:3], pc_coords_3d[1:3])
+
+    # convert 2D nodding correction to 3D transformation matrix
+    nod_correct_3d = np.eye(4)
+    nod_correct_3d[1:3, 1:3] = nod_correct_2d[:2, :2]  # Copy rotation part to y,z axes
+    nod_correct_3d[1:3, 3] = nod_correct_2d[:2, 2]  # Copy translation part to y,z axes (usually no translation)
+
+    ac_coords_after_nodding = apply_transform_to_pt(ac_coords_3d, nod_correct_3d, inv=False)
+    pc_coords_after_nodding = apply_transform_to_pt(pc_coords_3d, nod_correct_3d, inv=False)
+
+    ac_to_center_translation = np.eye(4)
+    ac_to_center_translation[0, 3] = image_center[0] - ac_coords_after_nodding[0]
+    ac_to_center_translation[1, 3] = image_center[1] - ac_coords_after_nodding[1]
+    ac_to_center_translation[2, 3] = image_center[2] - ac_coords_after_nodding[2]
+
+    # translate the AC to the image center
+    ac_coords_standardized = apply_transform_to_pt(ac_coords_after_nodding, ac_to_center_translation, inv=False)
+    pc_coords_standardized = apply_transform_to_pt(pc_coords_after_nodding, ac_to_center_translation, inv=False)
+
+    standardized_to_orig_vox2vox = np.linalg.inv(orig_fsaverage_vox2vox) @ np.linalg.inv(nod_correct_3d) @ np.linalg.inv(ac_to_center_translation)
+
+    # calculate AC & PC in the space of the MRI input image
+    ac_coords_orig = apply_transform_to_pt(ac_coords_standardized, standardized_to_orig_vox2vox, inv=False)
+    pc_coords_orig = apply_transform_to_pt(pc_coords_standardized, standardized_to_orig_vox2vox, inv=False)
+
+    return standardized_to_orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig
+
+
+def apply_transform_and_map_volume(volume, transform, affine, header, output_path=None, order=3, output_size=None):
+    """
+    Applies a transformation to a volume and optionally saves the result.
+
+    Args:
+        volume: Input volume data
+        transform: Transformation matrix to apply
+        affine: Affine matrix for the output image
+        header: Header for the output image
+        output_path: Path to save the transformed volume (optional)
+        order: Spline interpolation order (default=3)
+        output_size: Shape of the output volume (defaults to the input shape)
+
+    Returns:
+        transformed: Transformed volume data
+    """
+    if output_size is None:
+        output_size = np.array(volume.shape)
+    transformed = affine_transform(volume.astype(np.float32), np.linalg.inv(transform), output_shape=output_size, order=order)
+    if output_path is not None:
+        nib.save(nib.MGHImage(transformed, affine, header), output_path)
+    return transformed
+
+
+def make_affine(simpleITKImage):
+    """
+    Creates an affine transformation matrix from a SimpleITK image.
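+
+    The matrix is assembled from physical-point coordinates in LPS and then
+    converted to RAS to match nibabel's convention.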
+
+    Args:
+        simpleITKImage: Input SimpleITK image
+
+    Returns:
+        affine: 4x4 affine transformation matrix in RAS coordinates
+    """
+    # get affine transform in LPS
+    c = [simpleITKImage.TransformContinuousIndexToPhysicalPoint(p)
+         for p in ((1, 0, 0),
+                   (0, 1, 0),
+                   (0, 0, 1),
+                   (0, 0, 0))]
+    c = np.array(c)
+    affine = np.concatenate([
+        np.concatenate([c[0:3] - c[3:], c[3:]], axis=0),
+        [[0.], [0.], [0.], [1.]]
+    ], axis=1)
+    affine = np.transpose(affine)
+    # convert to RAS to match nibabel
+    affine = np.matmul(np.diag([-1., -1., 1., 1.]), affine)
+    return affine
+
+
+def map_softlabels_to_orig(outputs_soft, orig_fsaverage_vox2vox, orig, slices_to_analyze, orig_space_segmentation_path=None, fsaverage_middle=128):
+    """
+    Maps soft labels back to original image space and applies post-processing.
+
+    TODO: this could be done by padding after the transform
+
+    Args:
+        outputs_soft: Soft label predictions
+        orig_fsaverage_vox2vox: Original to fsaverage space transformation
+        orig: Original image
+        slices_to_analyze: Number of slices to analyze
+        orig_space_segmentation_path: Path to save the segmentation in original space (optional)
+        fsaverage_middle: Index of the mid-sagittal slice in fsaverage space (default=128)
+
+    Returns:
+        segmentation_orig_space: Final segmentation in original image space
+    """
+    # map softlabels to the original image
+    softlabels_transformed = []
+    for i in range(outputs_soft.shape[-1]):
+        # pad to original image size
+        outputs_soft_padded = np.zeros(orig.shape)
+        outputs_soft_padded[fsaverage_middle-slices_to_analyze//2:fsaverage_middle+slices_to_analyze//2+1] = outputs_soft[..., i]
+
+        s = affine_transform(
+            outputs_soft_padded,
+            orig_fsaverage_vox2vox,
+            output_shape=orig.shape,
+            order=1,
+            cval=1.0 if i == 0 else 0.0
+        )
+        softlabels_transformed.append(s)
+
+    softlabels_orig_space = np.stack(softlabels_transformed, axis=-1)
+
+    # apply softmax to softlabels_orig_space
+    softlabels_orig_space = np.exp(softlabels_orig_space) / np.sum(np.exp(softlabels_orig_space), axis=-1, keepdims=True)
+
+    # argmax over classes, then map class indices to labels (1 -> CC 192, 2 -> fornix 250)
+    segmentation_orig_space = np.argmax(softlabels_orig_space, axis=-1)
+    segmentation_orig_space = np.where(segmentation_orig_space == 1, 192, segmentation_orig_space)
+    segmentation_orig_space = np.where(segmentation_orig_space == 2, 250, segmentation_orig_space)
+
+    if orig_space_segmentation_path is not None:
+        nib.save(nib.MGHImage(segmentation_orig_space, orig.affine, orig.header), orig_space_segmentation_path)
+
+    return segmentation_orig_space
+
+def interpolate_midplane(orig, orig_fsaverage_vox2vox, slices_to_analyze):
+    """
+    Interpolates image data at the midplane using a grid of points.
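+
+    A grid of 9 + (slices_to_analyze - 1) sagittal slices centered on the
+    fsaverage midplane is mapped into the original voxel space and sampled
+    there with spline interpolation, so the full volume never has to be
+    resampled.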
+
+    Args:
+        orig: Original image
+        orig_fsaverage_vox2vox: Original to fsaverage space transformation
+        slices_to_analyze: Number of slices to analyze
+
+    Returns:
+        transformed: Interpolated image data at the midplane
+    """
+    # build a grid of 9+(slices_to_analyze-1) sagittal slices around the fsaverage middle
+    # (1 mm spacing, end points included)
+    x_coords = np.linspace(124-slices_to_analyze//2, 132+slices_to_analyze//2, 9+(slices_to_analyze-1), endpoint=True)
+    y_coords = np.linspace(0, orig.shape[1]-1, orig.shape[1], endpoint=True)  # full in-plane extent, 1-voxel spacing
+    z_coords = np.linspace(0, orig.shape[2]-1, orig.shape[2], endpoint=True)
+    X, Y, Z = np.meshgrid(x_coords, y_coords, z_coords, indexing='ij')
+
+    # Stack coordinates and add homogeneous coordinate
+    grid_fsaverage = np.stack([
+        X.ravel(),
+        Y.ravel(),
+        Z.ravel(),
+        np.ones(X.size)
+    ])
+
+    # move grid to orig space by applying the transform
+    grid_orig = np.linalg.inv(orig_fsaverage_vox2vox) @ grid_fsaverage
+
+    # interpolate the grid on the orig image
+    transformed = map_coordinates(
+        orig.get_fdata(),
+        grid_orig[0:3, :],  # use only x,y,z coordinates (drop homogeneous coordinate)
+        order=2,
+        mode='constant',
+        cval=0,
+        prefilter=True
+    ).reshape(len(x_coords), len(y_coords), len(z_coords))
+
+    return transformed
diff --git a/CorpusCallosum/segmentation/segmentation_inference.py b/CorpusCallosum/segmentation/segmentation_inference.py
new file mode 100644
index 00000000..21316d12
--- /dev/null
+++ b/CorpusCallosum/segmentation/segmentation_inference.py
@@ -0,0 +1,417 @@
+import time
+import torch
+import numpy as np
+import nibabel as nib
+
+from monai import transforms
+from monai.metrics import DiceMetric, HausdorffDistanceMetric
+
+from FastSurferCNN.models.networks import FastSurferVINN
+from transforms.segmentation_transforms import CropAroundACPC, UncropAroundACPC
+
+
+def load_model(checkpoint_path, device=None):
+    """
+    Load the trained model from checkpoint
+
+    Args:
+        checkpoint_path: Path to model checkpoint
+        device: torch device to load model to (defaults to CUDA if available)
+
+    Returns:
+        model: Loaded model
+    """
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    params = {
+        "num_classes": 3,
+        "num_filters": 71,
+        "num_filters_interpol": 32,
+        "num_channels": 9,
+        "kernel_h": 3,
+        "kernel_w": 3,
+        "kernel_c": 1,
+        "stride_conv": 1,
+        "stride_pool": 2,
+        "pool": 2,
+        "height": 128,
+        "width": 128,
+        "base_res": 1.0,
+        "interpolation_mode": "bilinear",
+        "crop_position": "top_left",
+        "out_tensor_width": 320,
+        "out_tensor_height": 320,
+    }
+    model = FastSurferVINN(params)
+
+    weights = torch.load(checkpoint_path, weights_only=True, map_location=device)
+    model.load_state_dict(weights)
+    model.eval()
+    model.to(device)
+    return model
+
+def run_inference(model, image_slice, AC_center, PC_center, voxel_size, device=None, transform=None):
+    """
+    Run inference on a single image slice
+
+    Args:
+        model: Trained model
+        image_slice: Input image as numpy array
+        AC_center: Anterior commissure coordinates used for cropping
+        PC_center: Posterior commissure coordinates used for cropping
+        voxel_size: In-plane voxel size in mm
+        device: torch device to run inference on
+        transform: Optional custom transform pipeline
+
+    Returns:
+        tuple: (outputs, inputs, outputs_avg, outputs_soft), all as (B, H, W, C) arrays:
+            the per-view one-hot segmentations, the network inputs, the batch-averaged
+            one-hot segmentation, and the raw (pre-softmax) network outputs
+    """
+    orig_shape = image_slice.shape
+
+    if device is None:
+        device = next(model.parameters()).device
+
+    crop_around_acpc = lambda img, ac, pc, vox_size: CropAroundACPC(keys=['image'], padding_mm=35, random_translate=0)(
+        {'image': img, 'AC_center': ac, 'PC_center': pc, 'res': vox_size})
+
+    # Preprocess slice
+    inputs = torch.from_numpy(image_slice[:, None, :256, :256])  # artifact from training script
+    crop_dict = crop_around_acpc(inputs, AC_center, PC_center, voxel_size)
+    inputs, to_pad = crop_dict['image'], crop_dict['to_pad']
+    inputs = transforms.utils.rescale_array(inputs, 0, 1, dtype=np.float32)
+    inputs = inputs.to(device)
+
+    post_trans = transforms.Compose(
+        [transforms.Activations(softmax=True), transforms.AsDiscrete(argmax=True, to_onehot=3)]
+    )
+
+    # split into views with 9 channels each (sliding window of 9 slices)
+    batch_size, channels, height, width = inputs.shape
+    views = []
+    for i in range(batch_size - 8):  # -8 to ensure we have 9 slices per view
+        view = inputs[i:i+9]  # Take 9 consecutive slices
+        view = view.reshape(1, 9*channels, height, width)  # Reshape to combine slices into channels
+        views.append(view)
+
+    inputs = torch.cat(views, dim=0)  # Stack all views into batch dimension
+
+    # Run inference
+    with torch.no_grad():
+        scale_factors = torch.ones((inputs.shape[0], 2), device=device) * (1 / voxel_size)
+
+        outputs = model(inputs, scale_factor=scale_factors)
+
+    # average the outputs along the batch dimension
+    outputs_avg = torch.mean(outputs, dim=0).unsqueeze(0)
+
+    # Post-process outputs
+    outputs_soft = outputs.cpu().numpy()  # raw network outputs (softmax not applied)
+    outputs = torch.stack([post_trans(i) for i in outputs])
+    outputs_avg = torch.stack([post_trans(i) for i in outputs_avg])
+
+    pad_left, pad_right, pad_top, pad_bottom = to_pad
+    # Pad back to original size
+    outputs = np.pad(outputs, ((0,0), (0,0), (pad_left.item(), pad_right.item()), (pad_top.item(), pad_bottom.item())), mode='constant', constant_values=0)
+    outputs_avg = np.pad(outputs_avg, ((0,0), (0,0), (pad_left.item(), pad_right.item()), (pad_top.item(), pad_bottom.item())), mode='constant', constant_values=0)
+    outputs_soft = np.pad(outputs_soft, ((0,0), (0,0), (pad_left.item(), pad_right.item()), (pad_top.item(), pad_bottom.item())), mode='constant', constant_values=0)
+
+    # restore original shape
+    if orig_shape[-2:] != outputs.shape[-2:]:
+        new_outputs = np.zeros((outputs.shape[0], outputs.shape[1], orig_shape[-2], orig_shape[-1]))
+        new_outputs[:, :, :256, :256] = outputs
+        outputs = new_outputs
+
+        new_outputs_avg = np.zeros((outputs_avg.shape[0], outputs_avg.shape[1], orig_shape[-2], orig_shape[-1]))
+        new_outputs_avg[:, :, :256, :256] = outputs_avg
+        outputs_avg = new_outputs_avg
+
+        new_outputs_soft = np.zeros((outputs_soft.shape[0], outputs_soft.shape[1], orig_shape[-2], orig_shape[-1]), dtype=np.float32)
+        new_outputs_soft[:, :, :256, :256] = outputs_soft
+        outputs_soft = new_outputs_soft
+
+    return outputs.transpose(0,2,3,1), inputs.cpu().numpy().transpose(0,2,3,1), outputs_avg.transpose(0,2,3,1), outputs_soft.transpose(0,2,3,1)
+
+# TODO: load validation data and run inference on it to confirm correct processing
+
+
+def load_validation_data(path):
+    import pandas as pd
+    data = pd.read_csv(path, index_col=0, header=None)
+    data.columns = ["image", "label", "AC_center_x", "AC_center_y", "AC_center_z",
+                    "PC_center_x", "PC_center_y", "PC_center_z"]
+
+    ac_centers = data[["AC_center_x", "AC_center_y", "AC_center_z"]].values
+    pc_centers = data[["PC_center_x", "PC_center_y", "PC_center_z"]].values
+    images = data["image"].values
+    labels = data["label"].values
+    subj_ids = data.index.values.tolist()
+
+    label_widths = []
+    for label_path in data['label']:
+        label_img = nib.load(label_path)
+
+        if label_img.shape[0] > 100:
+            # check which slices have non-zero values
+            label = label_img.get_fdata()
+            non_zero_slices = np.any(label > 0, axis=(1,2))
+            first_nonzero = np.argmax(non_zero_slices)
+            last_nonzero = len(non_zero_slices) - np.argmax(non_zero_slices[::-1])
+            label_widths.append(last_nonzero - first_nonzero)
+        else:
+            label_widths.append(label_img.shape[0])
+
+    return images, ac_centers, pc_centers, label_widths, labels, subj_ids
+
+
+def one_hot_to_label(one_hot, label_ids=[0, 192, 250]):
+    label = np.argmax(one_hot, axis=3)
+    if label_ids is not None:
+        label = np.where(label == 0, label_ids[0], label)
+        label = np.where(label == 1, label_ids[1], label)
+        label = np.where(label == 2, label_ids[2], label)
+
+    return label
+
+
+# TODO: add heuristic that removes islands that are far away
+
+
+def run_inference_on_slice(model, test_slice, AC_center, PC_center, voxel_size):
+    # add zero in front of AC_center and PC_center
+    AC_center = np.concatenate([np.zeros(1), AC_center])
+    PC_center = np.concatenate([np.zeros(1), PC_center])
+
+    results, inputs, outputs_avg, outputs_soft = run_inference(model, test_slice, AC_center, PC_center, voxel_size)
+    results = one_hot_to_label(results)
+
+    return results, inputs, outputs_avg, outputs_soft
+
+
+def remove_small_clusters(label_data, min_cluster_size=100):
+    """
+    Removes small clusters of connected components from a label image.
+
+    Args:
+        label_data: numpy array containing the label data
+        min_cluster_size: minimum size of clusters to keep (default: 100)
+
+    Returns:
+        cleaned_label: numpy array with small clusters removed
+    """
+    from scipy.ndimage import label as ndlabel
+
+    list_of_cleaned_labels = []
+
+    for label_id in range(label_data.shape[1]-1):
+        # Create a binary mask of the label
+        binary_mask = label_data[:, label_id+1] > 0
+
+        # Label the connected components
+        labeled_array, num_features = ndlabel(binary_mask)
+
+        # Create a mask for small clusters
+        small_clusters_mask = np.zeros_like(binary_mask, dtype=bool)
+        for i in range(1, num_features + 1):
+            small_cluster = (labeled_array == i)
+            if np.sum(small_cluster) < min_cluster_size:
+                small_clusters_mask |= small_cluster
+
+        # Remove small clusters from the original label
+        cleaned_label = label_data[:, label_id+1].copy()
+        cleaned_label[small_clusters_mask] = 0
+        list_of_cleaned_labels.append(cleaned_label)
+
+    return np.stack([label_data[:, 0]] + list_of_cleaned_labels, axis=1)
+
+
+def run_validation():
+    # Load model
+    model_path = "/groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/cc_pipeline/weights/segmentation_weights_cc_fn.pth"
+    model = load_model(model_path)
+
+    val_images, val_ac, val_pc, val_label_widths, val_labels, val_subj_ids = \
+        load_validation_data("/groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/network/data/difficult_joined_labels.csv")
+
+    dice_out = []
+    dice_out_single_slice = []
+    dice_out_dict = {}
+    dice_out_single_slice_dict = {}
+
+    # Initialize Hausdorff distance metric
+    hd_out = []
+    hd_out_single_slice = []
+    hd_out_dict = {}
+    hd_out_single_slice_dict = {}
+
+    for img_path, AC_center, PC_center, label_width, label_path, subj_id in zip(
+            val_images, val_ac, val_pc, val_label_widths, val_labels, val_subj_ids):
+
+        label_width = 5
+
+        test_img = nib.load(img_path)
+        test_slice = test_img.get_fdata()
+
+        # crop to the middle 9+5-1 (13) slices
+        test_slice = test_slice[256//2-label_width//2-4:256//2+label_width//2+5]
+
+        # Run inference
+        start_time = time.time()
+        results, inputs, outputs_avg, outputs_soft = run_inference(model, test_slice, AC_center, PC_center,
+                                                                   voxel_size=test_img.header.get_zooms()[0])
+        inference_time = time.time() - start_time
+        print(f"Inference took {inference_time:.3f} seconds")
+
+        label_img = nib.load(label_path)
+        label = label_img.get_fdata()
+
+        # calculate dice score
+        dice_metric = DiceMetric(include_background=False, reduction="mean")
+        hd_metric = HausdorffDistanceMetric(include_background=False, percentile=95.0, reduction="mean")
+
+        # Convert label to one-hot format
+        label_tensor = torch.from_numpy(label)
+
+        if label_tensor.shape[0] > 100:
+            # select non-zero slices
+            label_tensor = label_tensor[label_tensor.any(axis=(1,2))]
+
+        # crop to label width
+        label_tensor = label_tensor[label_tensor.shape[0]//2-label_width//2:label_tensor.shape[0]//2+label_width//2+1]
+
+        # map to 0,1,2
+        ids = np.unique(label)
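+        # NOTE: this assumes the label file contains exactly the background, CC
+        # and fornix ids (e.g. 0, 192, 250) in ascending order; with fewer than
+        # three unique ids the indexing below would fail.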
+        label_tensor = torch.where(label_tensor == ids[0], 0, label_tensor)
+        label_tensor = torch.where(label_tensor == ids[1], 1, label_tensor)
+        label_tensor = torch.where(label_tensor == ids[2], 2, label_tensor)
+
+        label_onehot = torch.nn.functional.one_hot(label_tensor.long(), num_classes=3)  # Convert to one-hot with 3 classes
+        label_onehot = label_onehot.permute(0, 3, 1, 2)  # Move class dimension to second position (B, C, H, W)
+
+        # Reshape results to (B, C, H, W)
+        results_tensor = torch.from_numpy(results)
+        results_tensor = results_tensor.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
+
+        # Remove small clusters
+        results_tensor = remove_small_clusters(results_tensor.numpy(), min_cluster_size=100)
+        results_tensor = torch.from_numpy(results_tensor)
+
+        # Calculate Dice score
+        dice_score = dice_metric(results_tensor, label_onehot)
+        midslice = results_tensor.shape[0]//2
+        dice_single_slice = dice_metric(results_tensor[None, midslice], label_onehot[None, midslice])
+
+        # Calculate Hausdorff distance
+        # Get physical spacing from the image header for accurate distance calculation
+        spacing = test_img.header.get_zooms()[:3]  # Get voxel dimensions in mm
+        if len(spacing) == 3:
+            # Use only in-plane spacing for 2D slices
+            spacing_tensor = torch.tensor([spacing[1], spacing[2]], dtype=torch.float32)
+        else:
+            spacing_tensor = torch.tensor(spacing, dtype=torch.float32)
+
+        hd_score = hd_metric(results_tensor, label_onehot, spacing=spacing_tensor.numpy().tolist())
+        hd_single_slice = hd_metric(results_tensor[None, midslice], label_onehot[None, midslice], spacing=spacing_tensor.numpy().tolist())
+
+        # Store results
+        dice_out.append(dice_score.mean(axis=0).numpy().tolist())
+        dice_out_single_slice.append(dice_single_slice.numpy().tolist())
+        dice_out_dict[subj_id] = dice_score.mean(axis=0).numpy().tolist()
+        dice_out_single_slice_dict[subj_id] = dice_single_slice.numpy()[0].tolist()
+
+        hd_out.append(hd_score.mean(axis=0).numpy().tolist())
+        hd_out_single_slice.append(hd_single_slice.numpy().tolist())
+        hd_out_dict[subj_id] = hd_score.mean(axis=0).numpy().tolist()
+        hd_out_single_slice_dict[subj_id] = hd_single_slice.numpy()[0].tolist()
+
+        print(f"Subject: {subj_id}")
+        print(f"Dice mean: {[f'{x:.3f}' for x in dice_score.mean(axis=0).numpy().tolist()]}")
+        print(f"HD95 mean: {[f'{x:.3f}' for x in hd_score.mean(axis=0).numpy().tolist()]} mm")
+
+        # Convert numpy arrays to NIfTI images before saving
+        nifti_img_out = nib.Nifti1Image(results, affine=test_img.affine, header=test_img.header)
+        nifti_img_in = nib.Nifti1Image(inputs, affine=test_img.affine, header=test_img.header)
+        nifti_orig_slice = nib.Nifti1Image(test_slice[4:-4], affine=test_img.affine, header=test_img.header)
+        nifti_avg_slice = nib.Nifti1Image(outputs_avg, affine=test_img.affine, header=test_img.header)
+        nifti_label = nib.Nifti1Image(label, affine=test_img.affine, header=test_img.header)
+        nifti_final_out = nib.Nifti1Image(one_hot_to_label(results), affine=test_img.affine, header=test_img.header)
+        nib.save(nifti_img_in, "/workspace/outputs/segmentation_input.nii.gz")
+        nib.save(nifti_img_out, "/workspace/outputs/segmentation.nii.gz")
+        nib.save(nifti_orig_slice, "/workspace/outputs/segmentation_orig.nii.gz")
+        nib.save(nifti_avg_slice, "/workspace/outputs/segmentation_avg.nii.gz")
+        nib.save(nifti_label, "/workspace/outputs/segmentation_label.nii.gz")
+        nib.save(nifti_final_out, "/workspace/outputs/segmentation_final.nii.gz")
+        import shutil
+        shutil.copy(img_path, "/workspace/outputs/segmentation_orig.mgz")
+        shutil.copy(label_path, "/workspace/outputs/segmentation_label.mgz")
+
+    print(f'Overall Validation Dice: {[f"{x:.3f}" for x in np.mean(dice_out, axis=0).tolist()]}')
+    print(f'Overall Validation HD95: {[f"{x:.3f}" for x in np.mean(hd_out, axis=0).tolist()]} mm')
+
+    import pandas as pd
+    # Save Dice scores
+    dice_out_df = pd.DataFrame.from_dict(dice_out_dict, orient='index', columns=["CC", "FN"])
+    dice_single_slice_df = pd.DataFrame.from_dict(dice_out_single_slice_dict, orient='index', columns=["CC", "FN"])
+    dice_out_df.to_csv("/workspace/outputs/dice_out.csv")
+    dice_single_slice_df.to_csv("/workspace/outputs/dice_single_slice.csv")
+
+    # Save Hausdorff distances
+    hd_out_df = pd.DataFrame.from_dict(hd_out_dict, orient='index', columns=["CC", "FN"])
+    hd_single_slice_df = pd.DataFrame.from_dict(hd_out_single_slice_dict, orient='index', columns=["CC", "FN"])
+    hd_out_df.to_csv("/workspace/outputs/hd_out.csv")
+    hd_single_slice_df.to_csv("/workspace/outputs/hd_single_slice.csv")
+
+    # Create a combined metrics dataframe
+    combined_metrics = pd.DataFrame()
+    combined_metrics['Dice_CC'] = dice_out_df['CC']
+    combined_metrics['Dice_FN'] = dice_out_df['FN']
+    combined_metrics['HD95_CC'] = hd_out_df['CC']
+    combined_metrics['HD95_FN'] = hd_out_df['FN']
+    combined_metrics.to_csv("/workspace/outputs/combined_metrics.csv")
+
+    # Testset:                    Overall Dice: ['0.957', '0.829']  HD95: ['1.018', '2.799']
+    # Testset only 5 slices:      Overall Dice: ['0.957', '0.831']  HD95: ['1.025', '2.318']
+    # Difficultset:               Overall Dice: ['0.944', '0.785']  HD95: ['1.189', '4.080']
+    # Difficultset only 5 slices: Overall Dice: ['0.946', '0.784']  HD95: ['1.155', '4.101']
+
+
+if __name__ == "__main__":
+    run_validation()
\ No newline at end of file
diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py
new file mode 100644
index 00000000..ea47b9df
--- /dev/null
+++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py
@@ -0,0 +1,132 @@
+import numpy as np
+from scipy import ndimage
+from skimage.measure import label
+
+from CorpusCallosum.data.constants import CC_LABEL, FORNIX_LABEL
+
+
+def get_cc_volume(desired_width_mm: int, cc_mask: np.ndarray, voxel_size: tuple[float, float, float]) -> float:
+    """Calculate the volume of the corpus callosum in cubic millimeters.
+
+    This function calculates the volume of the corpus callosum (CC) in cubic millimeters.
+    If the CC width is larger than desired_width_mm, the voxels on the edges are counted
+    as partial volumes to achieve the desired width.
+
+    Args:
+        desired_width_mm (int): Desired width of the CC in millimeters
+        cc_mask (np.ndarray): Binary mask of the corpus callosum
+        voxel_size (tuple[float, float, float]): Voxel size in millimeters (x, y, z)
+
+    Returns:
+        float: Volume of the CC in cubic millimeters
+
+    Raises:
+        ValueError: If the CC width is smaller than the desired width
+        AssertionError: If the CC mask doesn't have an odd number of voxels in the x dimension
+    """
+    assert cc_mask.shape[0] % 2 == 1, "CC mask must have odd number of voxels in x dimension"
+
+    # Calculate voxel volume
+    voxel_volume = np.prod(voxel_size)
+
+    # Get width of CC mask in voxels by finding the extent in x dimension
+    width_vox = np.sum(np.any(cc_mask, axis=(1,2)))
+
+    # we are in LIA, so 0 is L/R resolution
+    width_mm = width_vox * voxel_size[0]
+
+    if width_mm == desired_width_mm:
+        return np.sum(cc_mask) * voxel_volume
+    elif width_mm > desired_width_mm:
+        # remainder on the left/right side of the CC mask
+        desired_width_vox = desired_width_mm / voxel_size[0]
+        fraction_of_voxel_at_edge = (desired_width_vox % 1) / 2
+
+        if fraction_of_voxel_at_edge > 0:
+            desired_width_vox = int(np.floor(desired_width_vox) + 1)
+            desired_width_vox = desired_width_vox + 1 if desired_width_vox % 2 == 0 else desired_width_vox
+
+        assert cc_mask.shape[0] == desired_width_vox, f"CC mask should have {desired_width_vox} voxels, but has {cc_mask.shape[0]}"
+
+        left_partial_volume = np.sum(cc_mask[0]) * voxel_volume * fraction_of_voxel_at_edge
+        right_partial_volume = np.sum(cc_mask[-1]) * voxel_volume * fraction_of_voxel_at_edge
+        center_volume = np.sum(cc_mask[1:-1]) * voxel_volume
+        return left_partial_volume + right_partial_volume + center_volume
+    else:
+        raise ValueError(f"Width of CC segmentation is smaller than desired width: {width_mm} < {desired_width_mm}")
+
+
+def get_largest_cc(seg_arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+    """Get the largest connected component from a binary segmentation array.
+
+    This function takes a binary segmentation array, dilates it, finds connected components,
+    and returns the largest component (excluding background) along with its mask.
+
+    Args:
+        seg_arr (np.ndarray): Input binary segmentation array
+
+    Returns:
+        tuple: A tuple containing:
+            - clean_seg (np.ndarray): Segmentation array with only the largest connected component
+            - largest_cc (np.ndarray): Binary mask of the largest connected component
+    """
+    # generate dilation structure
+    struct1 = ndimage.generate_binary_structure(3, 3)
+    # Dilate prediction
+    mask = ndimage.binary_dilation(seg_arr, structure=struct1, iterations=1).astype(np.uint8)
+    # Get connected components
+    labels_cc = label(mask, connectivity=3, background=0)
+    # Count voxels per component
+    bincount = np.bincount(labels_cc.flat)
+    # Get background label, assuming the background is the biggest connected component
+    background = np.argmax(bincount)
+    bincount[background] = -1
+    # Get largest connected component
+    largest_cc = labels_cc == np.argmax(bincount)
+    # Apply mask
+    clean_seg = seg_arr * largest_cc
+
+    return clean_seg, largest_cc
+
+def clean_cc_segmentation(seg_arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+    """Clean corpus callosum segmentation by removing non-connected components.
+
+    This function processes a segmentation array to clean up the corpus callosum (CC)
+    by removing non-connected components.
+    It first isolates the CC (label 192), removes non-connected components,
+    then adds the fornix (label 250), and finally removes non-connected
+    components from the combined CC and fornix.
+
+    Args:
+        seg_arr (np.ndarray): Input segmentation array with CC (192) and fornix (250) labels
+
+    Returns:
+        tuple: A tuple containing:
+            - clean_seg (np.ndarray): Cleaned segmentation array with only the largest
+              connected component of CC and fornix
+            - mask (np.ndarray): Binary mask of the largest connected component
+    """
+    # Remove non-connected components from the CC alone
+    clean_seg = np.zeros_like(seg_arr)
+    clean_seg[seg_arr == CC_LABEL] = CC_LABEL
+    clean_seg, _ = get_largest_cc(clean_seg)
+
+    # Add fornix to the CC labels
+    clean_seg[seg_arr == FORNIX_LABEL] = FORNIX_LABEL
+
+    # Remove non-connected components from CC & fornix
+    clean_seg, mask = get_largest_cc(clean_seg)
+
+    # if a structure was removed entirely, restore it from the input segmentation
+    unique_labels = np.unique(clean_seg)
+
+    if FORNIX_LABEL not in unique_labels:
+        clean_seg[seg_arr == FORNIX_LABEL] = FORNIX_LABEL
+        mask[seg_arr == FORNIX_LABEL] = True
+    if CC_LABEL not in unique_labels:
+        clean_seg[seg_arr == CC_LABEL] = CC_LABEL
+        mask[seg_arr == CC_LABEL] = True
+    return clean_seg, mask
diff --git a/CorpusCallosum/shape/cc_endpoint_heuristic.py b/CorpusCallosum/shape/cc_endpoint_heuristic.py
new file mode 100644
index 00000000..43963a9e
--- /dev/null
+++ b/CorpusCallosum/shape/cc_endpoint_heuristic.py
@@ -0,0 +1,182 @@
+import nibabel as nib
+import numpy as np
+import skimage.measure
+import scipy.ndimage
+import pandas as pd
+from shape.resample_poly import iterative_resample_polygon
+
+def get_endpoints(cc_mask, AC_2d, PC_2d, resolution, return_coordinates=True, contour_smoothing=1.0):
+    """
+    Determines the endpoints of the CC by finding the contour points closest to the
+    anterior and posterior commissures (with some offsets).
+
+    NOTE: Expects LIA orientation
+    """
+    image_size = cc_mask.shape
+
+    # Calculate angle between the AC-PC line and the horizontal
+    ac_pc_vector = PC_2d - AC_2d
+    horizontal_vector = np.array([0, -20])
+    # Calculate angle using the dot product formula: cos(theta) = (a.b)/(|a||b|)
+    dot_product = np.dot(ac_pc_vector, horizontal_vector)
+    norms = np.linalg.norm(ac_pc_vector) * np.linalg.norm(horizontal_vector)
+    theta = np.arccos(dot_product / norms)
+
+    # convert theta from radians to degrees
+    theta_degrees = theta * 180 / np.pi
+    rotated_cc_mask = scipy.ndimage.rotate(cc_mask, -theta_degrees, order=0, reshape=False)
+
+    # rotate points around the image center
+    origin_point = np.array([image_size[0]//2, image_size[1]//2])
+
+    # Create rotation matrix for -theta
+    rot_matrix = np.array([[np.cos(-theta), -np.sin(-theta)],
+                           [np.sin(-theta), np.cos(-theta)]])
+
+    # Translate points to origin, rotate, then translate back
+    pc_centered = PC_2d - origin_point
+    ac_centered = AC_2d - origin_point
+
+    rotated_PC_2d = (rot_matrix @ pc_centered) + origin_point
+    rotated_AC_2d = (rot_matrix @ ac_centered) + origin_point
+
+    # get contour of the CC
+    gaussian_cc_mask = scipy.ndimage.gaussian_filter(rotated_cc_mask.astype(float), sigma=contour_smoothing)
+    contour = skimage.measure.find_contours(gaussian_cc_mask, level=0.5)[0].T
+
+    contour = iterative_resample_polygon(contour.T, 701).T
+    contour = contour[:, :-1]
+
+    rotated_AC_2d = np.array(rotated_AC_2d).astype(float)
+    rotated_PC_2d = np.array(rotated_PC_2d).astype(float)
+
+    # offset the posterior commissure inferior and posterior
+    rotated_PC_2d = rotated_PC_2d + np.array([10 * resolution, -5 * resolution])
+
+    # offset the anterior commissure anteriorly before the search
+    rotated_AC_2d = rotated_AC_2d + np.array([0, 5 * resolution])
+
+    # find point in contour closest to AC
+    AC_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_AC_2d[:, None], axis=0))
+
+    # find point in contour closest to PC
+    PC_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_PC_2d[:, None], axis=0))
+
+    # rotation matrix for rotating points back to the original orientation
+    rot_matrix = np.array([[np.cos(theta), -np.sin(theta)],
+                           [np.sin(theta), np.cos(theta)]])
+
+    origin_point = np.array(origin_point).astype(float)
+
+    # rotate contour back to original orientation:
+    # translate points to origin, rotate, then translate back
+    contour_centered = contour - origin_point[:, None]
+    contour_rotated = (rot_matrix @ contour_centered) + origin_point[:, None]
+
+    if return_coordinates:
+        AC_contour_point = contour[:, AC_startpoint_idx]
+        PC_contour_point = contour[:, PC_startpoint_idx]
+
+        # Translate points to origin, rotate, then translate back
+        ac_centered = AC_contour_point - origin_point
+        pc_centered = PC_contour_point - origin_point
+
+        start_point_A = (rot_matrix @ ac_centered) + origin_point
+        start_point_P = (rot_matrix @ pc_centered) + origin_point
+
+        return contour_rotated, start_point_A, start_point_P
+    else:
+        return contour_rotated, AC_startpoint_idx, PC_startpoint_idx
+
+
+def get_endpoints_from_nib(cc_label_nib, paths_csv, subj_id, return_coordinates=True):
+    """Extract the mid-sagittal CC mask and AC/PC points from a label image and run get_endpoints."""
+    cc_mask = cc_label_nib.get_fdata() == 192
+    cc_mask = cc_mask[cc_mask.shape[0]//2]
+
+    posterior_commissure_center = paths_csv.loc[subj_id, 'PC_center_r':'PC_center_s'].to_numpy().astype(float)
+    anterior_commissure_center = paths_csv.loc[subj_id, 'AC_center_r':'AC_center_s'].to_numpy().astype(float)
+
+    # adjust LR from label coordinates to orig_up coordinates
+    posterior_commissure_center[0] = 128
+    anterior_commissure_center[0] = 128
+
+    # orientation I, A
+    # rotate image so anterior and posterior commissures are horizontal
+    AC_2d = anterior_commissure_center[1:]
+    PC_2d = posterior_commissure_center[1:]
+
+    return get_endpoints(cc_mask, AC_2d, PC_2d, resolution=cc_label_nib.header.get_zooms()[1], return_coordinates=return_coordinates)
+
+
+if __name__ == "__main__":
+    from tqdm import tqdm
+    OUTPUT_TO_RAS = True
+    PLOT = False
+
+    paths_csv = pd.read_csv('/groups/ag-reuter-2/users/pollakc/corpus_callosum_fornix/pollakc/network/data/found_labels_with_meta_data_difficult_final.csv', index_col=0)
+
+    for subj_id in tqdm(paths_csv.index):
+        try:
+            cc_label_nib = nib.load(paths_csv.loc[subj_id, 'label_merged'])
+        except Exception as e:
+            print(subj_id, 'error', e)
+            continue
+
+        # if np.sum(cc_mask) < 20:
+        #     print(subj_id, 'skipping')
+        #     continue
+
+        contour, start_point_A, start_point_P = get_endpoints_from_nib(cc_label_nib, paths_csv, subj_id)
+
+        # if PLOT:
+        #     # Add visualization
+        #     import matplotlib.pyplot as plt
+        #     fig, ax = plt.subplots(figsize=(10, 8))
+        #     ax.imshow(cc_mask, cmap='gray')
+        #     ax.plot(contour[1], contour[0], 'b-', label='Contour')
+        #     # Plot initial endpoint estimates
+        #     ax.plot(start_point_A[1], start_point_A[0], 'rx',
+        #             markersize=8)
+        #     ax.plot(start_point_P[1], start_point_P[0], 'rx',
+        #             markersize=8, label='Ours')
+        #     ax.legend()
+        #     ax.set_title(f'Subject: {subj_id}')
+        #     # Save plot if desired
+        #     
#plt.savefig(f'./endpoint_plots/{subj_id}.png', dpi=300, bbox_inches='tight') + # plt.show() + # plt.close() + + + if OUTPUT_TO_RAS: + # use vox2ras matrix to convert to mm + vox2ras_matrix = cc_label_nib.affine + + # Add a third dimension (z) with 0 and a fourth dimension (homogeneous coordinate) with 1 + contour_homogeneous = np.vstack([contour, np.zeros(contour.shape[1]), np.ones(contour.shape[1])]) + start_point_A_homogeneous = np.hstack([start_point_A, [0, 1]]) + start_point_P_homogeneous = np.hstack([start_point_P, [0, 1]]) + + # Apply the transformation + contour = (vox2ras_matrix @ contour_homogeneous)[:3, :] + start_point_A = (vox2ras_matrix @ start_point_A_homogeneous)[:3] + start_point_P = (vox2ras_matrix @ start_point_P_homogeneous)[:3] + + + np.save(f'./contour_data/endpoints_{subj_id}.npy', np.array([start_point_A, start_point_P]), allow_pickle=False) + np.save(f'./contour_data/contours_{subj_id}.npy', np.array(contour), allow_pickle=False) + \ No newline at end of file diff --git a/CorpusCallosum/shape/cc_mesh.py b/CorpusCallosum/shape/cc_mesh.py new file mode 100644 index 00000000..b575d2c7 --- /dev/null +++ b/CorpusCallosum/shape/cc_mesh.py @@ -0,0 +1,1317 @@ +import tempfile +from pathlib import Path + +import numpy as np +import matplotlib +import matplotlib.pyplot as plt +import plotly.graph_objects as go +import nibabel as nib +import lapy +import pyrr +import scipy.interpolate + +from whippersnappy.core import snap1 +from shape.cc_thickness import make_mesh_from_contour, HiddenPrints + + +class CC_Mesh(lapy.TriaMesh): + """A class for representing and manipulating corpus callosum (CC) meshes. + + This class extends lapy.TriaMesh to provide specialized functionality for working with + corpus callosum meshes, including contour management, thickness measurements, and + visualization capabilities. + + The mesh can be constructed from a series of 2D contours representing slices of the + corpus callosum, with optional thickness measurements at various points along these + contours. + + Attributes: + contours (list): List of numpy arrays containing 2D contour points for each slice. + thickness_values (list): List of thickness measurements for each contour point. + start_end_idx (list): List of tuples containing start and end indices for each contour. + ac_coords (numpy.ndarray): Coordinates of the anterior commissure. + pc_coords (numpy.ndarray): Coordinates of the posterior commissure. + resolution (float): Spatial resolution of the mesh. + v (numpy.ndarray): Vertex coordinates of the mesh. + t (numpy.ndarray): Triangle indices of the mesh. + original_thickness_vertices (list): List of vertex indices where thickness was originally measured. + """ + + def __init__(self, num_slices): + """Initialize a CC_Mesh object. + + Args: + num_slices (int): Number of slices in the corpus callosum mesh. + """ + self.contours = [None] * num_slices + self.thickness_values = [None] * num_slices + self.start_end_idx = [None] * num_slices + self.ac_coords = None + self.pc_coords = None + self.resolution = None + self.v = None + self.t = None + self.original_thickness_vertices = [None] * num_slices + + def add_contour(self, slice_idx: int, contour: np.ndarray, thickness_values: np.ndarray, start_end_idx: tuple[int, int] | None = None): + """Add a contour and its associated thickness values for a specific slice. + + Args: + slice_idx (int): Index of the slice where the contour should be added. + contour (numpy.ndarray): Array of shape (N, 2) containing 2D contour points. 
+            thickness_values (numpy.ndarray): Array of thickness measurements for each contour point.
+            start_end_idx (tuple[int, int], optional): Tuple containing start and end indices for the contour.
+                If None, defaults to (0, len(contour)//2).
+        """
+        self.contours[slice_idx] = contour
+        self.thickness_values[slice_idx] = thickness_values
+        # remember the vertex indices where thickness values are not nan
+        self.original_thickness_vertices[slice_idx] = np.where(~np.isnan(thickness_values))[0]
+
+        if start_end_idx is None:
+            self.start_end_idx[slice_idx] = (0, len(contour)//2)
+        else:
+            self.start_end_idx[slice_idx] = start_end_idx
+
+    def set_acpc_coords(self, ac_coords: np.ndarray, pc_coords: np.ndarray):
+        """Set the coordinates of the anterior and posterior commissure.
+
+        Args:
+            ac_coords (numpy.ndarray): 3D coordinates of the anterior commissure.
+            pc_coords (numpy.ndarray): 3D coordinates of the posterior commissure.
+        """
+        self.ac_coords = ac_coords
+        self.pc_coords = pc_coords
+
+    def set_resolution(self, resolution: float):
+        """Set the spatial resolution of the mesh.
+
+        Args:
+            resolution (float): Spatial resolution in millimeters.
+        """
+        self.resolution = resolution
+
+    def plot_mesh(self, output_path: str | None = None,
+                  colormap: str = "red_to_yellow",
+                  thickness_overlay: bool = True,
+                  show_contours: bool = False,
+                  show_grid: bool = False,
+                  color_range: tuple[float, float] | None = None,
+                  show_mesh_edges: bool = False,
+                  legend: str = "",
+                  threshold: tuple[float, float] | None = None):
+        """Plot the mesh using Plotly for better performance and interactivity.
+
+        Creates an interactive 3D visualization of the mesh with optional features like
+        thickness overlay, contour display, and grid visualization. The plot can be saved
+        to an HTML file or displayed in a web browser.
+
+        Args:
+            output_path (str, optional): Path to save the plot. If None, displays the plot interactively.
+            colormap (str, optional): Which colormap to use. Options are:
+                - "red_to_blue": Red -> Orange -> Grey -> Light Blue -> Blue
+                - "red_to_yellow": Red -> Red-Orange -> Orange -> Yellow
+                - "yellow_to_red": Yellow -> Orange -> Red-Orange -> Red
+                - "blue_to_red": Blue -> Light Blue -> Grey -> Orange -> Red
+                Defaults to "red_to_yellow".
+            thickness_overlay (bool, optional): Whether to overlay thickness values on the mesh.
+                Defaults to True.
+            show_contours (bool, optional): Whether to show the contours. Defaults to False.
+            show_grid (bool, optional): Whether to show the grid. Defaults to False.
+            color_range (tuple[float, float], optional): Optional tuple of (min, max) to set fixed
+                color range. Defaults to None.
+            show_mesh_edges (bool, optional): Whether to show the mesh edges. Defaults to False.
+            legend (str, optional): Legend text for the colorbar. Defaults to "".
+            threshold (tuple[float, float], optional): If given, values between these thresholds
+                are shown in grey. Defaults to None (no grey region).
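+
+        Example:
+            A minimal sketch with a hypothetical output path and color range
+            (assumes contours were added and ``create_mesh`` has been called)::
+
+                mesh.plot_mesh(
+                    output_path="cc_mesh.html",   # written as interactive HTML
+                    colormap="red_to_yellow",
+                    thickness_overlay=True,
+                    color_range=(0.0, 12.0),      # fixed colorbar range
+                )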
+ """ + assert self.v is not None and self.t is not None, "Mesh has not been created yet" + + if len(self.v) == 0: + print("Warning: No vertices in mesh to plot") + return + + if len(self.t) == 0: + print("Warning: No faces in mesh to plot") + return + + # Define available colormaps + colormaps = { + "red_to_blue": [ + [0.0, "rgb(255,0,0)"], # Bright red + [0.25, "rgb(255,165,0)"], # Light orange + [0.5, "rgb(150,150,150)"], # Dark grey for middle + [0.75, "rgb(173,216,230)"], # Light blue + [1.0, "rgb(0,0,255)"] # Bright blue + ], + "blue_to_red": [ + [0.0, "rgb(0,0,255)"], # Bright blue + [0.25, "rgb(173,216,230)"], # Light blue + [0.5, "rgb(150,150,150)"], # Dark grey for middle + [0.75, "rgb(255,165,0)"], # Light orange + [1.0, "rgb(255,0,0)"] # Bright red + ], + "red_to_yellow": [ + [0.0, "rgb(255,0,0)"], # Bright red + [0.33, "rgb(255,85,0)"], # Red-orange + [0.66, "rgb(255,170,0)"], # Orange + [1.0, "rgb(255,255,0)"] # Yellow + ], + "yellow_to_red": [ + [0.0, "rgb(255,255,0)"], # Yellow + [0.33, "rgb(255,170,0)"], # Orange + [0.66, "rgb(255,85,0)"], # Red-orange + [1.0, "rgb(255,0,0)"] # Bright red + ] + } + + # Select the colormap + if colormap not in colormaps: + print(f"Warning: Unknown colormap '{colormap}'. Using 'red_to_blue' instead.") + colormap = "red_to_blue" + + selected_colormap = colormaps[colormap] + + # If threshold is provided, modify the colormap to include grey region + if threshold is not None and thickness_overlay and hasattr(self, 'mesh_vertex_colors'): + data_min = np.min(self.mesh_vertex_colors) if color_range is None else color_range[0] + data_max = np.max(self.mesh_vertex_colors) if color_range is None else color_range[1] + data_range = data_max - data_min + + # Calculate normalized threshold positions + thresh_low = (threshold[0] - data_min) / data_range + thresh_high = (threshold[1] - data_min) / data_range + + # Ensure thresholds are within [0,1] + thresh_low = max(0, min(1, thresh_low)) + thresh_high = max(0, min(1, thresh_high)) + + # Create new colormap with grey threshold region + grey_color = "rgb(150,150,150)" # Medium grey + new_colormap = [] + + # Add colors before threshold with adjusted positions + if thresh_low > 0: + for pos, color in selected_colormap: + if pos < 1: # Only use positions less than 1 + new_pos = pos * thresh_low + new_colormap.append([new_pos, color]) + + # Add threshold boundaries with grey + new_colormap.extend([ + [thresh_low, grey_color], + [thresh_high, grey_color] + ]) + + # Add colors after threshold with adjusted positions + if thresh_high < 1: + remaining_range = 1 - thresh_high + for pos, color in selected_colormap: + if pos > 0: # Only use positions greater than 0 + new_pos = thresh_high + pos * remaining_range + if new_pos <= 1: # Ensure we don't exceed 1 + new_colormap.append([new_pos, color]) + + selected_colormap = new_colormap + + # Calculate data ranges and center + xyz_min = self.v.min(axis=0) + xyz_max = self.v.max(axis=0) + xyz_range = xyz_max - xyz_min + max_range = xyz_range.max() + center = (xyz_max + xyz_min) / 2 + + # Create mesh plot + fig = go.Figure() + + # Add the mesh as a surface + mesh_args = { + 'x': self.v[:, 0], + 'y': self.v[:, 1], + 'z': self.v[:, 2], + 'i': self.t[:, 0], # First vertex of each triangle + 'j': self.t[:, 1], # Second vertex + 'k': self.t[:, 2], # Third vertex + 'hoverinfo': 'skip', + 'lighting': dict(ambient=0.9, diffuse=0.1, roughness=0.3) + } + + if thickness_overlay and hasattr(self, 'mesh_vertex_colors'): + mesh_args.update({ + 'intensity': self.mesh_vertex_colors, # Add 
intensity values for colorbar + 'showscale': True, + 'colorbar': dict( + title=dict( + text=legend, + font=dict(size=35, color='white'), # Increase title font size and make white + side='right' # Place title on right side + ), + len=0.55, # Make colorbar shorter + thickness=35, # Make colorbar wider + tickfont=dict(size=30, color='white'), # Increase tick font size and make white + tickformat='.1f', # Show one decimal place + ), + 'opacity': 1, + 'colorscale': selected_colormap + }) + + # Set the colorbar range + if color_range is not None: + mesh_args['cmin'] = color_range[0] + mesh_args['cmax'] = color_range[1] + else: + # Use data range if no explicit range is provided + mesh_args['cmin'] = np.min(self.mesh_vertex_colors) + mesh_args['cmax'] = np.max(self.mesh_vertex_colors) + else: + mesh_args['color'] = 'lightsteelblue' + + fig.add_trace(go.Mesh3d(**mesh_args)) + + if show_contours: + # Add contour polylines for reference + num_slices = len(self.contours) + + # Calculate z coordinates for each slice - use same calculation as in create_mesh + lr_center = self.v[len(self.v)//2][2] + z_coordinates = np.arange(num_slices) * self.resolution - (num_slices // 2) * self.resolution + lr_center + + for i in range(num_slices): + if self.contours[i] is not None: + # Use slice position for z coordinate + z_coord = z_coordinates[i] + contour = self.contours[i] + + # Create 3D points with fixed z coordinate + v_i = np.hstack([contour, np.full((len(contour), 1), z_coord)]) + + # Close the contour by adding the first point at the end + v_i = np.vstack([v_i, v_i[0]]) + + fig.add_trace(go.Scatter3d( + x=v_i[:, 0], + y=v_i[:, 1], + z=v_i[:, 2], + mode='lines', + line=dict(color='white', width=2), + opacity=0.5, + hoverinfo='skip', + showlegend=False + )) + if show_mesh_edges: # show the mesh edges + edge_color = 'darkgray' + vertices_in_first_contour = len(self.contours[0]) + + vertices_to_plot_first = np.concatenate([self.v[:vertices_in_first_contour], self.v[None,0]]) + # Add mesh edges for first 900 vertices as one continuous line + fig.add_trace(go.Scatter3d( + x=vertices_to_plot_first[:,0], + y=vertices_to_plot_first[:,1], + z=vertices_to_plot_first[:,2], + mode='lines', + line=dict(color=edge_color, width=8), + opacity=1, + hoverinfo='skip', + showlegend=False + )) + + vertices_in_last_contour = len(self.contours[-1]) + + vertices_before_last_contour = np.sum([len(c) for c in self.contours[:-1]]) + vertices_to_plot_last = np.concatenate([self.v[vertices_before_last_contour:vertices_before_last_contour + vertices_in_last_contour], self.v[None,vertices_before_last_contour]]) + fig.add_trace(go.Scatter3d( + x=vertices_to_plot_last[:,0], + y=vertices_to_plot_last[:,1], + z=vertices_to_plot_last[:,2], + mode='lines', + line=dict(color=edge_color, width=8), + opacity=1, + hoverinfo='skip', + showlegend=False + )) + + + # Calculate axis ranges to maintain equal aspect ratio + ranges = [] + for i in range(3): + axis_range = [ + center[i] - max_range/2, + center[i] + max_range/2 + ] + ranges.append(axis_range) + + # Configure axes and grid visibility + axis_config = dict( + showgrid=show_grid, + showline=show_grid, + zeroline=show_grid, + showbackground=show_grid, + showticklabels=show_grid, + gridcolor='white', + tickfont=dict(color='white'), + title=dict(font=dict(color='white')) + ) + + fig.update_layout( + scene=dict( + xaxis=dict(range=ranges[0], **{**axis_config, 'title': 'AP' if show_grid else ''}), + yaxis=dict(range=ranges[1], **{**axis_config, 'title': 'SI' if show_grid else ''}), + 
zaxis=dict(range=ranges[2], **{**axis_config, 'title': 'LR' if show_grid else ''}), + camera=dict( + eye=dict(x=1.5, y=1.5, z=1), + up=dict(x=0, y=0, z=1) + ), + aspectmode='cube', # Force equal aspect ratio + aspectratio=dict(x=1, y=1, z=1), + bgcolor='black', + dragmode='orbit' # Enable orbital rotation by default + ), + showlegend=False, + margin=dict(l=0, r=100, t=0, b=0), # Increased right margin for colorbar + paper_bgcolor='black', + plot_bgcolor='black' + ) + + if output_path is not None: + fig.write_html(output_path) # Save as interactive HTML + else: + # For non-interactive display, save to a temporary HTML and open in browser + import tempfile + import webbrowser + import os + + temp_path = os.path.join(tempfile.gettempdir(), 'cc_mesh_plot.html') + fig.write_html(temp_path) + webbrowser.open('file://' + temp_path) + + + def get_contour_edge_lengths(self, contour_idx): + """Get the lengths of the edges of a contour. + + Args: + contour_idx (int): Index of the contour to get the edge lengths for. + + Returns: + numpy.ndarray: Array of edge lengths for the contour. + """ + edges = np.diff(self.contours[contour_idx], axis=0) + return np.sqrt(np.sum(edges**2, axis=1)) + + + @staticmethod + def make_triangles_between_contours(contour1, contour2): + """Creates a triangular mesh between two contours using a robust method. + + This method creates triangles that connect two contours by matching points between them. + It starts from the closest point on contour2 to the first point of contour1 and creates + triangles by connecting corresponding points. + + Args: + contour1 (numpy.ndarray): First contour points of shape (N, 2). + contour2 (numpy.ndarray): Second contour points of shape (M, 2). + + Returns: + numpy.ndarray: Array of triangle indices of shape (K, 3) where K is the number of triangles. 
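+
+        Example:
+            A small illustrative sketch (synthetic triangle contours, values chosen
+            only for demonstration)::
+
+                c1 = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 1.0]])
+                c2 = np.array([[0.0, 0.2], [1.0, 0.2], [0.5, 1.2]])
+                tris = CC_Mesh.make_triangles_between_contours(c1, c2)
+                # two triangles per point of c1 -> tris.shape == (6, 3);
+                # indices >= len(c1) refer to points of c2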
+        """
+        start_idx_c1 = 0
+        # get closest point on contour2 to contour1[0]
+        start_idx_c2 = np.argmin(np.linalg.norm(contour2 - contour1[0], axis=1))
+
+        triangles = []
+        n1 = len(contour1)
+        n2 = len(contour2)
+
+        for i in range(n1):
+            # Current and next indices for contour1
+            c1_curr = (start_idx_c1 + i) % n1
+            c1_next = (start_idx_c1 + i + 1) % n1
+
+            # Current and next indices for contour2, offset by n1 to account for vertex stacking
+            c2_curr = ((start_idx_c2 + i) % n2) + n1
+            c2_next = ((start_idx_c2 + i + 1) % n2) + n1
+
+            # Create two triangles to form a quad between the contours
+            triangles.append([c1_curr, c2_curr, c1_next])
+            triangles.append([c2_curr, c2_next, c1_next])
+
+        return np.array(triangles)
+
+    def _create_levelpaths(self, contour_idx, points, trias, num_points=None):
+        """Compute level paths of a Poisson solution on a 2D slice mesh.
+
+        Solves a Poisson problem with boundary conditions anchored at the stored
+        contour endpoints and extracts one level path (with its associated
+        thickness value) per evaluation point.
+        """
+        # compute Poisson solution on the slice mesh
+        with HiddenPrints():
+            cc_tria = lapy.TriaMesh(points, trias)
+            # extract boundary curve
+            bdr = np.array(cc_tria.boundary_loops()[0])
+
+        # find index of endpoints in bdr list
+        iidx1 = np.where(bdr == self.start_end_idx[contour_idx][0])[0][0]
+        iidx2 = np.where(bdr == self.start_end_idx[contour_idx][1])[0][0]
+
+        # create boundary condition (0 at endpoints, -1 on one side, 1 on the other):
+        if iidx1 > iidx2:
+            iidx1, iidx2 = iidx2, iidx1
+        dcond = np.ones(bdr.shape)
+        dcond[iidx1] = 0
+        dcond[iidx2] = 0
+        dcond[iidx1+1:iidx2] = -1
+
+        # Extract path
+        with HiddenPrints():
+            fem = lapy.Solver(cc_tria)
+            vfunc = fem.poisson(0, (bdr, dcond))
+        if num_points is not None:
+            # evaluate along an equidistantly resampled midline (level 0)
+            level = 0
+            midline_equidistant, midline_length = cc_tria.level_path(vfunc, level, n_points=num_points+2)
+            midline_equidistant = midline_equidistant[:, :2]
+            eval_points = midline_equidistant
+        else:
+            eval_points = self.contours[contour_idx]
+        gf = lapy.diffgeo.compute_rotated_f(cc_tria, vfunc)
+
+        # interpolate the level function at the evaluation points
+        gf_interp = scipy.interpolate.griddata(cc_tria.v[:, 0:2], gf, eval_points, method='nearest')
+
+        # sort by value
+        sorting_idx_gf = np.argsort(gf_interp)
+        gf_interp = gf_interp[sorting_idx_gf]
+        sorted_thickness_values = self.thickness_values[contour_idx][sorting_idx_gf]
+
+        levelpaths = []
+        thickness_values = []
+
+        for i in range(len(eval_points)):
+            level = gf_interp[i]
+            # levelpath starts at index zero
+            if level == 0:
+                continue
+            lvlpath, lvlpath_length, tria_idx = cc_tria.level_path(gf, level, get_tria_idx=True)
+
+            levelpaths.append(lvlpath)
+            thickness_values.append(sorted_thickness_values[i])
+
+        return levelpaths, thickness_values
+
+    def _create_cap(self, points, trias, contour_idx):
+        """Build a colored cap mesh for one end of the CC from level paths."""
+        levelpaths, thickness_values = self._create_levelpaths(contour_idx, points, trias)
+
+        # Create mesh from level paths
+        level_vertices = []
+        level_faces = []
+        level_colors = []
+        vertex_counter = 0
+        sorted_thickness_values = np.array(thickness_values)
+
+        # smooth thickness values
+        from scipy.ndimage import gaussian_filter1d
+        for i in range(3):
+            sorted_thickness_values = gaussian_filter1d(sorted_thickness_values, sigma=5)
+
+        NUM_LEVELPOINTS = 50
+
+        assert len(sorted_thickness_values) == len(levelpaths)
+
+        # TODO: handle gap between first/last levelpath and contour
+        for idx, levelpath1 in enumerate(levelpaths):
+
+            levelpath1 = lapy.TriaMesh._TriaMesh__iterative_resample_polygon(levelpath1, NUM_LEVELPOINTS)
+            level_vertices.append(levelpath1)
+            level_colors.append(np.full((len(levelpath1)), sorted_thickness_values[idx]))
+            if idx + 1 < 
len(levelpaths): + levelpath2 = lapy.TriaMesh._TriaMesh__iterative_resample_polygon(levelpaths[idx + 1], NUM_LEVELPOINTS) + + # Create faces between the two paths by connecting vertices + faces_between = [] + i, j = 0, 0 + + while i < len(levelpath1)-1 and j < len(levelpath2)-1: + faces_between.append([i, i+1, len(levelpath1)+j]) + faces_between.append([i+1, len(levelpath1)+j+1, len(levelpath1)+j]) + + i += 1 + j += 1 + + while i < len(levelpath1)-1: + faces_between.append([i, i+1, len(levelpath1)+j]) + i += 1 + + while j < len(levelpath2)-1: + faces_between.append([i, len(levelpath1)+j+1, len(levelpath1)+j]) + j += 1 + + if faces_between: + faces_between = np.array(faces_between) + level_faces.append(faces_between + vertex_counter) + + vertex_counter += len(levelpath1) + + # Convert to numpy arrays + level_vertices = np.vstack(level_vertices) + level_faces = np.vstack(level_faces) + level_colors = np.concatenate(level_colors) + + return level_vertices, level_faces, level_colors + + def create_mesh(self, lr_center: float = 0, closed: bool = False, smooth: int = 0): + """Creates a surface mesh by triangulating between consecutive contours. + + This method constructs a 3D mesh from the stored contours by creating triangles between + adjacent slices. It can optionally create a closed mesh by adding caps at the ends and + apply smoothing. + + Args: + lr_center (float, optional): Center position in the left-right axis. Defaults to 0. + closed (bool, optional): Whether to create a closed mesh by adding caps. Defaults to False. + smooth (int, optional): Number of smoothing iterations to apply. Defaults to 0. + """ + # Filter out None contours and get their indices + valid_contours = [(i, c) for i, c in enumerate(self.contours) if c is not None] + if not valid_contours: + print("Warning: No valid contours found") + self.v = np.array([]) + self.t = np.array([]) + return + + # Calculate z coordinates for each slice + z_coordinates = np.arange(len(valid_contours)) * self.resolution - (len(valid_contours) // 2) * self.resolution + lr_center + + # Build vertices list with z-coordinates + vertices = [] + faces = [] + vertex_start_indices = [] # Track starting index for each contour + current_index = 0 + + for i, (idx, contour) in enumerate(valid_contours): + vertex_start_indices.append(current_index) + vertices.append(np.hstack([contour, np.full((len(contour), 1), z_coordinates[i])])) + + # Check if there's a next valid contour to connect to + if i + 1 < len(valid_contours): + next_idx, contour2 = valid_contours[i + 1] + faces_between = self.make_triangles_between_contours(contour, contour2) + faces.append(faces_between + current_index) + + current_index += len(contour) + + + + self.set_mesh(vertices, faces, self.thickness_values) + + if smooth > 0: + self.smooth_(smooth) + + + if closed: + # Close the mesh by creating caps on both ends + # Left cap (first slice) - use counterclockwise orientation + left_side_points, left_side_trias = make_mesh_from_contour(self.v[:vertex_start_indices[1]][..., :2]) + left_side_points = np.hstack([left_side_points, np.full((len(left_side_points), 1), z_coordinates[0])]) + + # Right cap (last slice) - reverse points for proper orientation + right_side_points, right_side_trias = make_mesh_from_contour(self.v[vertex_start_indices[-1]:][..., :2]) + right_side_points = np.hstack([right_side_points, np.full((len(right_side_points), 1), z_coordinates[-1])]) + + color_sides = True + if color_sides: + left_side_points, left_side_trias, left_side_colors = 
self._create_cap(left_side_points, left_side_trias, 0) + right_side_points, right_side_trias, right_side_colors = self._create_cap(right_side_points, right_side_trias, len(self.contours) - 1) + + # reverse right side trias + right_side_trias = right_side_trias[:,::-1] + + + left_side_trias = left_side_trias + current_index + current_index += len(left_side_points) + + right_side_trias = right_side_trias + current_index + current_index += len(right_side_points) + + self.set_mesh([self.v, left_side_points, right_side_points], [self.t, left_side_trias, right_side_trias], [self.mesh_vertex_colors, left_side_colors, right_side_colors]) + + + def fill_thickness_values(self): + """ + Interpolate missing thickness values on the contours by weighted average of nearest known thickness values. + """ + + # For each contour with missing thickness values + for i in range(len(self.contours)): + if self.contours[i] is None or self.thickness_values[i] is None: + continue + + thickness = self.thickness_values[i] + edge_lengths = self.get_contour_edge_lengths(i) + + # Find indices of points with known thickness + known_idx = np.where(~np.isnan(thickness))[0] + + # For each point with unknown thickness + for j in range(len(thickness)): + if not np.isnan(thickness[j]): + continue + + # Find two closest points with known thickness + distances = np.zeros(len(known_idx)) + for k, idx in enumerate(known_idx): + # Calculate distance along contour by summing edge lengths + if idx > j: + distances[k] = np.sum(edge_lengths[j:idx]) + else: + distances[k] = np.sum(edge_lengths[idx:j]) + + # Get indices of two closest points + closest_indices = known_idx[np.argsort(distances)[:2]] + closest_distances = np.sort(distances)[:2] + + # Calculate weights based on inverse distance + weights = 1.0 / closest_distances + weights = weights / np.sum(weights) + + # Calculate weighted average thickness + thickness[j] = np.sum(weights * thickness[closest_indices]) + + self.thickness_values[i] = thickness + + + def smooth_thickness_values(self, iterations: int = 1): + """ + Smooth the thickness values using a Gaussian filter + """ + from scipy.ndimage import gaussian_filter1d + for i in range(len(self.thickness_values)): + if self.thickness_values[i] is not None: + self.thickness_values[i] = gaussian_filter1d(self.thickness_values[i], sigma=5) + + + def plot_contour(self, slice_idx: int, output_path: str): + """Plot a single contour with thickness values. + + Creates a 2D visualization of a specific contour slice with points colored according + to their thickness values. The plot is saved to the specified output path. + + Args: + slice_idx (int): Index of the slice to plot. + output_path (str): Path where to save the plot. + + Raises: + ValueError: If the contour for the specified slice is not set. 
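+
+        Example:
+            Hypothetical usage, assuming contour 3 was added via ``add_contour``::
+
+                mesh.plot_contour(slice_idx=3, output_path="slice3_thickness.png")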
+ """ + + if self.contours[slice_idx] is None: + raise ValueError(f'Contour for slice {slice_idx} is not set') + + contour = self.contours[slice_idx] + + plt.figure(figsize=(15, 10)) + # Get thickness values for this slice + thickness = self.thickness_values[slice_idx] + + # Plot points with colors based on thickness + for i in range(len(contour)): + if np.isnan(thickness[i]): + plt.plot(contour[i,0], contour[i,1], 'o', color='gray', markersize=1) + else: + # Map thickness to color from red to yellow + plt.plot(contour[i,0], contour[i,1], 'o', color=plt.cm.YlOrRd(thickness[i]/np.nanmax(thickness)), markersize=1) + + # Connect points with lines + plt.plot(contour[:,0], contour[:,1], '-', color='black', alpha=0.3, label='Contour') + plt.axis('equal') + plt.xlabel('X') + plt.ylabel('Y') + plt.title(f'CC contour for slice {slice_idx}') + plt.legend() + plt.grid(True) + plt.tight_layout() + plt.savefig(output_path, dpi=300) + + + + def smooth_contour(self, contour_idx, window_size=5): + """ + Smooth a contour using a moving average filter + + Parameters: + ----------- + contour : tuple of arrays + The contour coordinates (x, y) + window_size : int + Size of the smoothing window + + Returns: + -------- + tuple of arrays + The smoothed contour coordinates (x, y) + """ + x, y = self.contours[contour_idx].T + + # Ensure the window size is odd + if window_size % 2 == 0: + window_size += 1 + + # Create a padded version of the arrays to handle the edges + x_padded = np.pad(x, (window_size//2, window_size//2), mode='wrap') + y_padded = np.pad(y, (window_size//2, window_size//2), mode='wrap') + + # Apply moving average + x_smoothed = np.zeros_like(x) + y_smoothed = np.zeros_like(y) + + for i in range(len(x)): + x_smoothed[i] = np.mean(x_padded[i:i+window_size]) + y_smoothed[i] = np.mean(y_padded[i:i+window_size]) + + self.contours[contour_idx] = np.array([x_smoothed, y_smoothed]).T + + + def plot_cc_contour_with_levelsets(self, contour_idx=0, levelpaths=None, title=None, save_path=None, colorbar=True): + """Plot a contour with levelset visualization. + + Creates a visualization of a contour with interpolated levelsets, useful for + analyzing the thickness distribution across the corpus callosum. + + Args: + contour_idx (int, optional): Index of the contour to plot. Defaults to 0. + levelpaths (list, optional): List of levelset paths. If None, uses stored levelpaths. + title (str, optional): Title for the plot. Defaults to None. + save_path (str, optional): Path to save the plot. If None, displays interactively. + colorbar (bool, optional): Whether to show the colorbar. Defaults to True. + + Returns: + matplotlib.figure.Figure: The created figure object. 
+ """ + + plot_values = np.array(self.thickness_values[contour_idx][~np.isnan(self.thickness_values[contour_idx])][:100])[::-1] + # double plot values with linear interpolation + + + # Create bar plot of thickness values + # fig, ax = plt.subplots(figsize=(10, 4)) + # ax.bar(range(len(plot_values)), plot_values) + # ax.set_xlabel('Point Index') + # ax.set_ylabel('Thickness (mm)') + # ax.set_title('Thickness Distribution') + # ax.set_ylim(0, 0.06) + # ax.invert_xaxis() + # plt.tight_layout() + # plt.show() + + points, trias = make_mesh_from_contour(self.contours[contour_idx], max_volume=0.5, min_angle=25, verbose=False) + + # make points 3D by adding zero + points = np.column_stack([points, np.zeros(len(points))]) + + levelpaths, _ = self._create_levelpaths(contour_idx, points, trias, num_points=99) + + outside_contour = self.contours[contour_idx].T + + # Create a grid of points covering the contour area with higher resolution + x_min, x_max = np.min(outside_contour[0]), np.max(outside_contour[0]) + y_min, y_max = np.min(outside_contour[1]), np.max(outside_contour[1]) + margin = 1 + resolution = 0.05 # Higher resolution for smoother interpolation + x_grid, y_grid = np.meshgrid( + np.arange(x_min - margin, x_max + margin, resolution), + np.arange(y_min - margin, y_max + margin, resolution) + ) + + # Create a path from the outside contour + contour_path = matplotlib.path.Path(np.column_stack([outside_contour[0], outside_contour[1]])) + + # Check which points are inside the contour + points = np.column_stack([x_grid.flatten(), y_grid.flatten()]) + mask = contour_path.contains_points(points).reshape(x_grid.shape) + + # Collect all levelpath points and their corresponding values + # Extend each levelpath at both ends to improve extrapolation + all_level_points_x = [] + all_level_points_y = [] + all_level_values = [] + + for i, path in enumerate(levelpaths): + + if len(path) == 1: + all_level_points_x.append(path[0][0]) + all_level_points_y.append(path[0][1]) + all_level_values.append(plot_values[i]) + continue + + # make levelpath + path = lapy.TriaMesh._TriaMesh__resample_polygon(path, 1000) + + + + # Extend at the beginning: add point in direction opposite to first segment + first_segment = path[1] - path[0] + # standardize length of first segment + first_segment = first_segment / np.linalg.norm(first_segment) * 10 + extension_start = path[0] - first_segment + all_level_points_x.append(extension_start[0]) + all_level_points_y.append(extension_start[1]) + all_level_values.append(plot_values[i]) + + # Add original path points + for point in path: + all_level_points_x.append(point[0]) + all_level_points_y.append(point[1]) + all_level_values.append(plot_values[i]) + + + # Extend at the end: add point in direction of last segment + last_segment = path[-1] - path[-2] + # standardize length of last segment + last_segment = last_segment / np.linalg.norm(last_segment) * 10 + extension_end = path[-1] + last_segment + all_level_points_x.append(extension_end[0]) + all_level_points_y.append(extension_end[1]) + all_level_values.append(plot_values[i]) + + # Convert to numpy arrays + all_level_points_x = np.array(all_level_points_x) + all_level_points_y = np.array(all_level_points_y) + all_level_values = np.array(all_level_values) + + # Use griddata to perform smooth interpolation - using 'linear' instead of 'cubic' + # and properly formatting the input points + grid_values = scipy.interpolate.griddata( + (all_level_points_x, all_level_points_y), + all_level_values, + (x_grid, y_grid), + method='linear', + 
fill_value=0 + ) + + # smooth the grid_values + grid_values = scipy.ndimage.gaussian_filter(grid_values, sigma=5, radius=5) + + # Apply the mask to only show values inside the contour + masked_values = np.where(mask, grid_values, np.nan) + + + # Sample colormaps (e.g., 'binary' and 'gist_heat_r') + colors1 = plt.cm.binary([0.4] * 128) + colors2 = plt.cm.hot(np.linspace(0.8, 0.1, 128)) + + # Combine the color samples + colors = np.vstack((colors2, colors1)) + + # Create a new colormap + cmap = matplotlib.colors.LinearSegmentedColormap.from_list('my_colormap', colors) + + + + # Plot CC contour with levelsets + fig = plt.figure(figsize=(10,3)) + # Apply a 10-degree rotation to the entire plot + base = plt.gca().transData + transform = matplotlib.transforms.Affine2D().rotate_deg(10) + transform = transform + base + + # Plot the filled contour with interpolated colors + plt.imshow(masked_values, extent=[x_min-margin, x_max+margin, y_min-margin, y_max+margin], + origin='lower', cmap=cmap, alpha=1, interpolation='bilinear', vmin=0, vmax=0.10, transform=transform) + + plt.imshow(masked_values, + extent=[x_min-margin, x_max+margin, y_min-margin, y_max+margin], + origin='lower', cmap=cmap, alpha=1, interpolation='bilinear', + vmin=0, vmax=0.10, + #norm=LogNorm(vmin=1e-3, vmax=0.1), # Set minimum to avoid log(0) + transform=transform) + + if colorbar: + # Add a colorbar + cbar = plt.colorbar(aspect=10) + cbar.ax.set_ylim(0.001, 0.054) + cbar.ax.set_yticks([0.0, 0.01, 0.02, 0.03, 0.04, 0.05]) + #cbar.ax.set_yticks([0.001, 0.01, 0.05]) + #cbar.ax.set_yticklabels(['0.001', '0.01', '0.05']) + cbar.set_label('p-value (log scale)') + + # Plot the outside contour on top for clear boundary + plt.plot(outside_contour[0], outside_contour[1], 'k-', linewidth=2, label='CC Contour', transform=transform) + + + # plot levelpaths + #for i, path in enumerate(levelpaths): + # plt.plot(path[:,0], path[:,1], 'k--', linewidth=1, alpha=0.2, transform=transform) + # plot midline + # if midline_equidistant is not None: + # midline_x, midline_y = zip(*midline_equidistant) + # plt.plot(midline_x, midline_y, 'k--', linewidth=2, transform=transform, alpha=0.2) + + + plt.axis('equal') + plt.title(title, fontsize=14, fontweight='bold') + #plt.legend(loc='best') + plt.gca().invert_xaxis() + plt.axis('off') + #plt.tight_layout() + # plt.ylim(-105, -75) + # plt.xlim(181, 101) + if save_path is not None: + plt.savefig(save_path, dpi=300) + plt.show() + return fig + + + + def set_mesh(self, vertices, faces, thickness_values=None): + """Set the mesh vertices, faces, and optional thickness values. + + Args: + vertices (list or numpy.ndarray): List of vertex coordinates or array of shape (N, 3). + faces (list or numpy.ndarray): List of face indices or array of shape (M, 3). + thickness_values (list or numpy.ndarray, optional): Thickness values for each vertex. 
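+
+        Example:
+            Minimal sketch with one synthetic quad split into two triangles
+            (illustrative coordinates only)::
+
+                verts = [np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0],
+                                   [1.0, 1.0, 0.0], [0.0, 1.0, 0.0]])]
+                faces = [np.array([[0, 1, 2], [0, 2, 3]])]
+                mesh.set_mesh(verts, faces, thickness_values=None)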
+ """ + # Handle case when there are no faces (single contour) + if not faces: + # For single contour, just store vertices without creating a mesh + vertices_array = np.vstack(vertices) if vertices else np.array([]).reshape(0, 3) + self.v = vertices_array + self.t = np.array([]).reshape(0, 3) + # Initialize fsinfo attribute that lapy expects + self.fsinfo = None + # Skip parent initialization since we have no faces + else: + super().__init__(np.vstack(vertices), np.vstack(faces)) + + if thickness_values is not None: + # Filter out empty thickness arrays and concatenate + valid_thickness = [tv for tv in thickness_values if tv is not None and len(tv) > 0] + if valid_thickness: + self.mesh_vertex_colors = np.concatenate(valid_thickness) + else: + self.mesh_vertex_colors = np.array([]) + + @staticmethod + def __create_cc_viewmat(): + """ + Create the view matrix for a nice view of the corpus callosum. + """ + viewLeft = np.array([[ 0, 0,-1, 0], [-1, 0, 0, 0], [ 0, 1, 0, 0], [ 0, 0, 0, 1]]) # left w top up // right + transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) + viewmat = transl * viewLeft + + + + #rotate 10 degrees around x axis + rot = pyrr.Matrix44.from_x_rotation(np.deg2rad(-10)) + viewmat = viewmat * rot + + # rotate 35 degrees around y axis + rot = pyrr.Matrix44.from_y_rotation(np.deg2rad(35)) + viewmat = viewmat * rot + + # rotate 10 degrees around z axis + rot = pyrr.Matrix44.from_z_rotation(np.deg2rad(-8)) + viewmat = viewmat * rot + + return viewmat + + def snap_cc_picture(self, output_path: str): + """Snap a picture of the corpus callosum mesh. + + Takes a snapshot of the mesh from a predefined viewpoint, with optional thickness + overlay. The image is saved to the specified output path. + + Args: + output_path (str): Path where to save the snapshot image. + + Note: + This method uses a temporary file to store the mesh and overlay data during + the snapshot process. + """ + # Skip snapshot if there are no faces + if len(self.t) == 0: + print("Warning: Cannot create snapshot - no faces in mesh") + return + + # create temp file + temp_file = tempfile.NamedTemporaryFile(suffix='.fssurf', delete=True) + self.write_fssurf(temp_file.name) + + # Write thickness values as overlay + if hasattr(self, 'mesh_vertex_colors'): + overlay_file = tempfile.NamedTemporaryFile(suffix='.w', delete=True) + # Write thickness values in FreeSurfer .w format + nib.freesurfer.write_morph_data(overlay_file.name, self.mesh_vertex_colors) + overlaypath = overlay_file.name + else: + overlaypath = None + + snap1( + temp_file.name, + overlaypath=overlaypath, + view=None, + viewmat=self.__create_cc_viewmat(), + width=3*500, + height=3*300, + outpath=output_path, + ambient=0.6, + colorbar_scale=0.5, + colorbar_y=0.88, + colorbar_x=0.19, + brain_scale=2.1, + fthresh=0, + caption='Corpus Callosum thickness (mm)', + caption_y=0.85, + caption_x=0.17, + caption_scale=0.5 + ) + + temp_file.close() + overlay_file.close() + + def smooth_(self, iterations: int = 1): + """Smooth the mesh while preserving the z-coordinates. + + This method applies Laplacian smoothing to the mesh vertices while keeping + the z-coordinates unchanged to maintain the slice structure. + + Args: + iterations (int, optional): Number of smoothing iterations. Defaults to 1. + """ + z_values = self.v[:, 2] + super().smooth_(iterations) + self.v[:, 2] = z_values + + + def save_contours(self, output_path: str): + """Save the contours to a CSV file. + + Saves all contours and their associated endpoint indices to a CSV file. 
+ The file format is: + slice_idx,x,y + where each point of each contour gets its own row, with special lines indicating + the start of new contours and their endpoint indices. + + Args: + output_path (str): Path where to save the CSV file. + """ + with open(output_path, 'w') as f: + # Write header + f.write("slice_idx,x,y\n") + # Write data + for slice_idx, contour in enumerate(self.contours): + if contour is not None: # Skip empty slices + f.write(f"New contour, anterior_endpoint_idx={self.start_end_idx[slice_idx][0]},posterior_endpoint_idx={self.start_end_idx[slice_idx][1]}\n") + for point in contour: + f.write(f"{slice_idx},{point[0]},{point[1]}\n") + + def load_contours(self, input_path: str): + """Load contours from a CSV file. + + Loads contours and their associated endpoint indices from a CSV file. + The file format should match that produced by save_contours: + slice_idx,x,y with special lines for endpoint indices. + + Args: + input_path (str): Path to the CSV file containing the contours. + + Note: + This method will reset any existing contours and endpoint indices. + """ + current_points = [] + self.contours = [] + self.start_end_idx = [] + + with open(input_path, 'r') as f: + # Skip header + next(f) + + for line in f: + if line.startswith('New contour'): + # If we have points from previous contour, save them + if current_points: + self.contours.append(np.array(current_points)) + current_points = [] + + # Extract anterior and posterior endpoint indices + # Format: "New contour, anterior_endpoint_idx=X,posterior_endpoint_idx=Y" + parts = line.strip().split(',') + anterior_idx = int(parts[1].split('=')[1]) + posterior_idx = int(parts[2].split('=')[1]) + self.start_end_idx.append((anterior_idx, posterior_idx)) + else: + # Parse point data + slice_idx, x, y = line.strip().split(',') + current_points.append([float(x), float(y)]) + + # Don't forget to add the last contour + if current_points: + self.contours.append(np.array(current_points)) + + # Convert lists to fixed-size arrays + max_slices = max(len(self.contours), len(self.start_end_idx)) + self.contours = self.contours + [None] * (max_slices - len(self.contours)) + self.start_end_idx = self.start_end_idx + [None] * (max_slices - len(self.start_end_idx)) + + def save_thickness_values(self, output_path: str): + """Save thickness values to a CSV file. + + Saves all thickness values to a CSV file in the format: + slice_idx,thickness + where each thickness value gets its own row. + + Args: + output_path (str): Path where to save the CSV file. + """ + with open(output_path, 'w') as f: + # Write header + f.write("slice_idx,thickness\n") + # Write data + for slice_idx, thickness in enumerate(self.thickness_values): + if thickness is not None: # Skip empty slices + for value in thickness: + f.write(f"{slice_idx},{value}\n") + + def load_thickness_values(self, input_path: str, original_thickness_vertices_path: str | None = None): + """Load thickness values from a CSV file. + + Loads thickness values from a CSV file and optionally associates them with specific + vertices using a measurement points file. + + Args: + input_path (str): Path to the CSV file containing thickness values. + original_thickness_vertices_path (str, optional): Path to a file containing the + indices of vertices where thickness was measured. If None, assumes thickness + values correspond to all vertices in order. + + Raises: + ValueError: If the number of thickness values doesn't match the number of + measurement points, or if the number of slices is inconsistent. 
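+
+        Example:
+            Round-trip sketch with hypothetical file names::
+
+                mesh.save_thickness_values("thickness.csv")
+                mesh.save_thickness_measurement_points("points.csv")
+                mesh.load_thickness_values("thickness.csv",
+                                           original_thickness_vertices_path="points.csv")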
+ """ + data = np.loadtxt(input_path, delimiter=',', skiprows=1) + slice_indices = data[:, 0].astype(int) + values = data[:, 1] + + # Group values by slice_idx + unique_slices = np.unique(slice_indices) + + # split data into slices + loaded_thickness_values = [None] * (max(unique_slices) + 1) + for slice_idx in unique_slices: + mask = slice_indices == slice_idx + loaded_thickness_values[slice_idx] = values[mask] + + if original_thickness_vertices_path is None: + # check that the number of thickness values for each slice is equal to the number of points in the contour + for slice_idx, thickness in enumerate(loaded_thickness_values): + if thickness is not None: + assert len(thickness) == len(self.contours[slice_idx]), \ + "Number of thickness values does not match number of points in the contour, maybe you need to provide the measurement points file" + # fill original_thickness_vertices with all indices + self.original_thickness_vertices = [np.arange(len(self.contours[slice_idx])) for slice_idx in range(len(self.contours))] + else: + loaded_original_thickness_vertices = self._load_thickness_measurement_points(original_thickness_vertices_path) + + if len(loaded_original_thickness_vertices) != len(loaded_thickness_values): + raise ValueError("Number of slices in measurement points does not match number of slices in provided thickness values") + + # check that original_thickness_vertices is equal to number of measurement points for each slice + for slice_idx, vertex_indices in enumerate(loaded_original_thickness_vertices): + if len(vertex_indices) // 2 == len(loaded_thickness_values[slice_idx]) or len(vertex_indices) // 2 == np.sum(~np.isnan(loaded_thickness_values[slice_idx])): + is_thickness_profile = True + elif len(vertex_indices) == len(loaded_thickness_values[slice_idx]) or len(vertex_indices) == np.sum(~np.isnan(loaded_thickness_values[slice_idx])): + is_thickness_profile = False + else: + raise ValueError("Number of measurement points does not match number of thickness values") + + + # create nan thickness value array for each slice + new_thickness_values = [np.full(len(self.contours[slice_idx]), np.nan) for slice_idx in range(len(self.contours))] + for slice_idx, vertex_indices in enumerate(loaded_original_thickness_vertices): + if is_thickness_profile: + new_thickness_values[slice_idx][vertex_indices] = np.concatenate([loaded_thickness_values[slice_idx],loaded_thickness_values[slice_idx][::-1]]) + else: + new_thickness_values[slice_idx][vertex_indices] = loaded_thickness_values[slice_idx][~np.isnan(loaded_thickness_values[slice_idx])] + self.thickness_values = new_thickness_values + + + def to_fs_coordinates(self): + """Convert mesh coordinates to FreeSurfer coordinate system. + + Transforms the mesh vertices from the original coordinate system to the + FreeSurfer coordinate system by reordering axes and applying appropriate offsets. + """ + self.v = self.v[:, [2, 0, 1]] + self.v[:, 1] -= 128 + self.v[:, 2] += 128 + + def write_fssurf(self, filename): + """Write the mesh to a FreeSurfer surface file. + + Args: + filename (str): Path where to save the FreeSurfer surface file. + + Returns: + The result of the parent class's write_fssurf method. + """ + return super().write_fssurf(filename) + + def write_overlay(self, filename): + """Write the thickness values as a FreeSurfer overlay file. + + Args: + filename (str): Path where to save the overlay file. + + Returns: + The result of writing the morph data using nibabel. 
+ """ + return nib.freesurfer.write_morph_data(filename, self.mesh_vertex_colors) + + def save_thickness_measurement_points(self, filename): + """Write the thickness measurement points to a CSV file. + + Saves the indices of vertices where thickness was measured for each slice + in CSV format: slice_idx,vertex_idx + + Args: + filename (str): Path where to save the CSV file. + """ + with open(filename, 'w') as f: + f.write("slice_idx,vertex_idx\n") + for slice_idx, vertex_indices in enumerate(self.original_thickness_vertices): + if vertex_indices is not None: + for vertex_idx in vertex_indices: + f.write(f"{slice_idx},{vertex_idx}\n") + + @staticmethod + def _load_thickness_measurement_points(filename): + """Load thickness measurement points from a CSV file. + + Args: + filename (str): Path to the CSV file containing measurement points. + + Returns: + list: List of arrays containing vertex indices for each slice where + thickness was measured. + """ + data = np.loadtxt(filename, delimiter=',', skiprows=1) + slice_indices = data[:, 0].astype(int) + vertex_indices = data[:, 1].astype(int) + + # Group values by slice_idx + unique_slices = np.unique(slice_indices) + + # split data into slices + original_thickness_vertices = [None] * (max(unique_slices) + 1) + for slice_idx in unique_slices: + mask = slice_indices == slice_idx + original_thickness_vertices[slice_idx] = vertex_indices[mask] + return original_thickness_vertices \ No newline at end of file diff --git a/CorpusCallosum/shape/cc_metrics.py b/CorpusCallosum/shape/cc_metrics.py new file mode 100644 index 00000000..dbb7b7d4 --- /dev/null +++ b/CorpusCallosum/shape/cc_metrics.py @@ -0,0 +1,178 @@ +import numpy as np + +def calculate_cc_index(cc_contour): + """ + Calculate CC index based on three perpendicular measurements. 
+
+    Args:
+        cc_contour: 2xN array of contour points in ACPC space
+
+    Returns:
+        float: CC index, i.e. the sum of the three thickness measurements
+            (anterior, posterior, and at the AP midpoint) divided by the
+            anterior-posterior length
+    """
+    # Get anterior and posterior points
+    anterior_idx = np.argmin(cc_contour[0])  # Leftmost point
+    posterior_idx = np.argmax(cc_contour[0])  # Rightmost point
+
+    # Get the longest line (anterior to posterior)
+    ap_line = cc_contour[:,posterior_idx] - cc_contour[:,anterior_idx]
+    ap_length = np.linalg.norm(ap_line)
+    # unit vector perpendicular to the AP line
+    ap_unit = np.array([-ap_line[1], ap_line[0]]) / ap_length
+
+    # Get midpoint of AP line
+    midpoint = cc_contour[:,anterior_idx] + (ap_line/2)
+
+    # Get intersection points with contour for each measurement line
+    def get_intersections(start_point, direction):
+        # Get all points above and below the line through start_point
+        points = cc_contour.T - start_point[None,:]
+        dots = np.dot(points, direction)
+        signs = np.sign(dots)
+        sign_changes = np.where(np.diff(signs))[0]
+
+        intersections = []
+        for idx in sign_changes:
+            # Linear interpolation between points
+            t = -dots[idx] / (dots[idx+1] - dots[idx])
+            intersection = cc_contour[:,idx] + t * (cc_contour[:,idx+1] - cc_contour[:,idx])
+            intersections.append(intersection)
+
+        return np.array(intersections)
+
+    # Get three measurements
+    most_anterior_pt = cc_contour[:,anterior_idx]
+    # unit vector parallel to the AP line (perpendicular to ap_unit)
+    perpendicular_unit = np.array([-ap_unit[1], ap_unit[0]])
+
+    # intersections with a line parallel to the AP line near the anterior tip
+    anterior_intersections = get_intersections(most_anterior_pt - 10*perpendicular_unit, ap_unit)
+
+    # sort by x
+    anterior_intersections = anterior_intersections[np.argsort(anterior_intersections[:,0])]
+
+    # intersections with the line perpendicular to the AP line at its midpoint
+    middle_ints = get_intersections(midpoint, perpendicular_unit)
+
+    if len(middle_ints) != 2:
+        print(f"WARNING: The perpendicular line should intersect the contour twice, but it intersects {len(middle_ints)} times")
+
+    # calculate index
+    ap_distance = np.linalg.norm(anterior_intersections[0] - anterior_intersections[-1])
+    anterior_distance = np.linalg.norm(anterior_intersections[0] - anterior_intersections[1])
+    posterior_distance = np.linalg.norm(anterior_intersections[-1] - anterior_intersections[-2])
+    top_distance = np.linalg.norm(middle_ints[0] - middle_ints[1])
+
+    index = (anterior_distance + posterior_distance + top_distance) / ap_distance
+
+    return index
+
+
+if __name__ == "__main__":
+    from cc_thickness import convert_to_ras
+    from shape.cc_endpoint_heuristic import get_endpoints
+    import pandas as pd
+    import nibabel as nib
+    from tqdm import tqdm
+    # Compute the CC index for a set of subjects
+
+    paths_csv = pd.read_csv('/groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/network/data/found_labels_with_meta_data_difficult_final.csv', index_col=0)
+
+    for subj_num, subj_id in enumerate(tqdm(paths_csv.index)):
+        label_path = paths_csv.loc[subj_id, 'label_merged']
+
+        try:
+            cc_label_nib = nib.load(label_path)
+        except Exception as e:
+            print(subj_id, 'error', e)
+            continue
+
+        PC_2d = paths_csv.loc[subj_id, 'PC_center_r':'PC_center_s'].to_numpy().astype(float)[1:]
+        AC_2d = paths_csv.loc[subj_id, 'AC_center_r':'AC_center_s'].to_numpy().astype(float)[1:]
+
+        cc_mask = cc_label_nib.get_fdata() == 192
+        cc_mask = cc_mask[cc_mask.shape[0]//2]
+
+        contour, anterior_endpoint_idx, posterior_endpoint_idx = get_endpoints(cc_mask, AC_2d, PC_2d, cc_label_nib.header.get_zooms()[1], return_coordinates=False)
+
+        contour_2d = convert_to_ras(contour, cc_label_nib.affine)
+
+        index = calculate_cc_index(contour_2d)
+
+        print(subj_id, index)
+
diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/cc_postprocessing.py
new file mode 100644
index 00000000..70fdc154
--- /dev/null
+++ b/CorpusCallosum/shape/cc_postprocessing.py
@@ -0,0 +1,300 @@
+from pathlib import Path
+
+import nibabel as nib
+import numpy as np
+
+from shape.cc_thickness import convert_to_ras, cc_thickness
+from shape.cc_subsegment_contour import subdivide_contour, transform_to_acpc_standard, subsegment_midline_orthogonal, hampel_subdivide_contour
+from shape.cc_endpoint_heuristic import get_endpoints
+from shape.cc_metrics import calculate_cc_index
+from shape.cc_subsegment_contour import get_primary_eigenvector
+from shape.cc_mesh import CC_Mesh
+from CorpusCallosum.visualization.visualization import plot_contours
+from CorpusCallosum.data.read_write import run_in_background
+from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE
+
+
+# LIA orientation matrix, used to assert the expected input orientation
+LIA_ORIENTATION = np.zeros((3,3))
+LIA_ORIENTATION[0,0] = -1
+LIA_ORIENTATION[1,2] = 1
+LIA_ORIENTATION[2,1] = -1
+
+
+def create_visualization(subdivision_method, result, midslices_data, output_image_path, ac_coords, pc_coords, vox_size, title_suffix=""):
+    """Helper function to create visualization plots based on subdivision method.
+
+    Args:
+        subdivision_method: The subdivision method being used
+        result: Dictionary containing processing results with split_contours and split_contours_hofer_frahm
+        midslices_data: Slice data for visualization
+        output_image_path: Output path for the visualization
+        ac_coords: AC coordinates
+        pc_coords: PC coordinates
+        vox_size: Voxel size of the input image
+        title_suffix: Additional text to append to the title
+
+    Returns:
+        Process object for background execution
+    """
+    title = f'CC Subsegmentation: {subdivision_method}{title_suffix}'
+
+    if subdivision_method == "shape":
+        return run_in_background(plot_contours, False, midslices_data,
+                result['split_contours'], None, result['midline_equidistant'], result['levelpaths'],
+                output_image_path, ac_coords, pc_coords, vox_size, title)
+    else:
+        return run_in_background(plot_contours, False, midslices_data,
+                None, result['split_contours_hofer_frahm'], result['midline_equidistant'], result['levelpaths'],
+                output_image_path, ac_coords, pc_coords, vox_size, title)
+
+
+def create_slice_affine(temp_seg_affine, slice_idx, fsaverage_middle):
+    """Create slice-specific affine transformation matrix.
+
+    Adjusts the input affine transformation matrix for a specific slice by updating
+    the translation component based on the slice index and fsaverage middle reference.
+
+    Args:
+        temp_seg_affine (np.ndarray): Base 4x4 affine transformation matrix
+        slice_idx (int): Index of the slice to transform
+        fsaverage_middle (int): Reference middle slice index in fsaverage space
+
+    Returns:
+        np.ndarray: Modified 4x4 affine transformation matrix for the specific slice
+    """
+    slice_affine = temp_seg_affine.copy()
+    slice_affine[0, 3] = -fsaverage_middle + slice_idx
+    return slice_affine
+
+
+def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thickness_points, subdivisions, subdivision_method, contour_smoothing, verbose=False):
+    """Process a single slice for corpus callosum measurements.
+
+    Performs detailed analysis of a corpus callosum slice, including:
+    - Contour extraction and endpoint detection
+    - Thickness profile calculation
+    - Area and perimeter measurements
+    - Shape-based metrics (circularity, CC index)
+    - Subdivision into anatomical regions
+
+    Args:
+        segmentation (np.ndarray): 3D segmentation array
+        slice_idx (int): Index of the slice to process
+        ac_coords (np.ndarray): Anterior commissure coordinates
+        pc_coords (np.ndarray): Posterior commissure coordinates
+        affine (np.ndarray): 4x4 affine transformation matrix
+        num_thickness_points (int): Number of points for thickness estimation
+        subdivisions (list[float]): List of fractions for anatomical subdivisions
+        subdivision_method (str): Method for contour subdivision ('shape', 'vertical',
+            'angular', or 'eigenvector')
+        contour_smoothing (float): Gaussian sigma for contour smoothing
+        verbose (bool): Whether to print progress information
+
+    Returns:
+        tuple: A 4-tuple of (measurements, contour_with_thickness,
+            anterior_endpoint_idx, posterior_endpoint_idx), where measurements
+            is a dictionary including:
+            - cc_index: Corpus callosum shape index
+            - circularity: Shape circularity measure
+            - areas: Areas of subdivided regions
+            - midline_length: Length along the midline
+            - thickness: Mean thickness over all level paths
+            - curvature: Mean absolute curvature of the midline
+            - thickness_profile: Thickness measurements along the contour
+            - total_area: Total area of the CC
+            - total_perimeter: Total perimeter length
+            - split_contours: Subdivided contour segments
+            - split_contours_hofer_frahm: Alternative subdivision (if applicable)
+            - midline_equidistant: Equidistant points along midline
+            - levelpaths: Paths for thickness measurements
+            - slice_index: Index of the processed slice
+
+    Raises:
+        ValueError: If no CC is found in the slice, or if the requested
+            subdivisions are invalid for the chosen method.
+    """
+
+    cc_mask_slice = segmentation[slice_idx] == 192
+    if not np.any(cc_mask_slice):
+        raise ValueError(f'No CC found in slice {slice_idx}')
+
+    contour, anterior_endpoint_idx, posterior_endpoint_idx = get_endpoints(cc_mask_slice, ac_coords, pc_coords, affine.diagonal()[1], return_coordinates=False, contour_smoothing=contour_smoothing)
+    contour_1mm = convert_to_ras(contour, affine)
+
+    midline_length, thickness, curvature, midline_equidistant, levelpaths, contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx = cc_thickness(contour_1mm.T, anterior_endpoint_idx, posterior_endpoint_idx, n_points=num_thickness_points)
+    thickness_profile = [np.sum(np.sqrt(np.diff(np.array(levelpath[:,:2]),axis=0)**2),axis=0) for levelpath in levelpaths]
+    thickness_profile = np.linalg.norm(np.array(thickness_profile),axis=1)
+
+    contour_acpc, ac_pt_acpc, pc_pt_acpc, rotate_back_acpc = transform_to_acpc_standard(contour_1mm, contour_1mm[:,anterior_endpoint_idx], contour_1mm[:,posterior_endpoint_idx])
+    cc_index = calculate_cc_index(contour_acpc)
+
+    # Apply different subdivision methods based on user choice
+    if subdivision_method == "shape":
+        areas, split_contours = subsegment_midline_orthogonal(midline_equidistant, subdivisions, contour_1mm, plot=False)
+        split_contours = [transform_to_acpc_standard(split_contour, contour_1mm[:,anterior_endpoint_idx], contour_1mm[:,posterior_endpoint_idx])[0] for split_contour in split_contours]
+        split_contours_hofer_frahm = None
+    elif subdivision_method == "vertical":
+        areas, split_contours = subdivide_contour(contour_acpc, subdivisions, plot=False)
+        split_contours_hofer_frahm = split_contours.copy()
+    elif subdivision_method == "angular":
+        if not np.allclose(np.diff(subdivisions), np.diff(subdivisions)[0]):
+            raise ValueError(f'Angular subdivision method (Hampel) only supports equidistant subdivisions, but got: {subdivisions}')
+        areas, split_contours = hampel_subdivide_contour(contour_acpc, num_rays=len(subdivisions), plot=False)
+        split_contours_hofer_frahm = split_contours.copy()
+    elif subdivision_method == "eigenvector":
+        pt0, pt1 = get_primary_eigenvector(contour_acpc)
+        contour_eigen, _, _, rotate_back_eigen = transform_to_acpc_standard(contour_acpc, pt0, pt1)
+        ac_pt_eigen, _, _, _ = transform_to_acpc_standard(ac_pt_acpc[:, None], pt0, pt1)
+        ac_pt_eigen = ac_pt_eigen[:, 0]
+        areas, split_contours = subdivide_contour(contour_eigen, subdivisions, oriented=True, hline_anchor=ac_pt_eigen)
+        split_contours = [rotate_back_eigen(split_contour) for split_contour in split_contours]
+        split_contours_hofer_frahm = split_contours.copy()
+
+    total_area = np.sum(areas)
+    # perimeter: sum of segment lengths along the 2xN contour (diff over points, i.e. axis=1)
+    total_perimeter = np.sum(np.sqrt(np.sum(np.diff(contour_1mm, axis=1)**2, axis=0)))
+    circularity = 4 * np.pi * total_area / (total_perimeter**2)
+
+    # Transform split contours back to original space
+    split_contours = [rotate_back_acpc(split_contour) for split_contour in split_contours]
+    if split_contours_hofer_frahm is not None:
+        split_contours_hofer_frahm = [rotate_back_acpc(split_contour) for split_contour in split_contours_hofer_frahm]
+
+    return {
+        'cc_index': cc_index,
+        'circularity': circularity,
+        'areas': areas,
+        'midline_length': midline_length,
+        'thickness': thickness,
+        'curvature': curvature,
+        'thickness_profile': thickness_profile,
+        'total_area': total_area,
+        'total_perimeter': total_perimeter,
+        'split_contours': split_contours,
+        'split_contours_hofer_frahm': split_contours_hofer_frahm,
+        'midline_equidistant': midline_equidistant,
+        'levelpaths': levelpaths,
+        'slice_index': slice_idx
+    }, contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx
+
+
+def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac_coords, pc_coords,
+                   num_thickness_points, subdivisions, subdivision_method, contour_smoothing,
+                   output_dir, debug_image_path=None, vox_size=None, verbose=False, save_template=None):
+    """Process corpus callosum slices based on selection mode.
+
+    Handles the processing of either a single middle slice, all slices, or a specific slice,
+    including affine transformations and measurements for each slice.
+
+    Args:
+        segmentation (np.ndarray): 3D segmentation array
+        slice_selection (str): Which slices to process ('middle', 'all', or slice number)
+        temp_seg_affine (np.ndarray): Base affine transformation matrix
+        midslices (np.ndarray): Array of mid-sagittal slices
+        ac_coords (np.ndarray): Anterior commissure coordinates
+        pc_coords (np.ndarray): Posterior commissure coordinates
+        num_thickness_points (int): Number of points for thickness estimation
+        subdivisions (list[float]): List of fractions for anatomical subdivisions
+        subdivision_method (str): Method for contour subdivision
+        contour_smoothing (float): Gaussian sigma for contour smoothing
+        output_dir (Path): Base output directory
+        debug_image_path (str, optional): Path for debug visualization image
+        vox_size (float, optional): Voxel size passed to the visualizations
+        verbose (bool): Whether to print progress information
+        save_template (str | Path | None): Directory path where to save template files, or None to skip saving
+
+    Returns:
+        tuple: Contains:
+            - list: List of slice processing results
+            - list: List of background IO processes
+    """
+    slice_results = []
+    IO_processes = []
+
+    if slice_selection == "middle":
+        cc_mesh = CC_Mesh(num_slices=1)
+        cc_mesh.set_acpc_coords(ac_coords, pc_coords)
+        cc_mesh.set_resolution(1)  # contour is always scaled to 1 mm
+
+        # Process only the middle slice
+        slice_idx = segmentation.shape[0] // 2
+        slice_affine = create_slice_affine(temp_seg_affine, slice_idx, FSAVERAGE_MIDDLE)
+
+        result, contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx = process_slice(segmentation, slice_idx, ac_coords, pc_coords, slice_affine,
+                                                                            num_thickness_points, subdivisions, subdivision_method, contour_smoothing)
+
+        cc_mesh.add_contour(0, contour_with_thickness[0], contour_with_thickness[1], start_end_idx=(anterior_endpoint_idx, posterior_endpoint_idx))
+
+        if result is not None:
+            slice_results.append(result)
+            # Create visualization
+            IO_processes.append(create_visualization(subdivision_method, result, midslices,
+                                debug_image_path, ac_coords, pc_coords, vox_size))
+    else:
+        cc_mesh = CC_Mesh(num_slices=segmentation.shape[0])
+        cc_mesh.set_acpc_coords(ac_coords, pc_coords)
+        cc_mesh.set_resolution(1)  # contour is always scaled to 1 mm
+
+        # Process multiple slices or specific slice
+        if slice_selection == "all":
+            start_slice = 0
+            end_slice = segmentation.shape[0]
+        else:  # specific slice number
+            slice_idx = int(slice_selection)
+            start_slice = slice_idx
+            end_slice = slice_idx + 1
+
+        for slice_idx in range(start_slice, end_slice):
+            if verbose:
+                print(f"Calculating CC measurements for slice {slice_idx+1} of {end_slice-start_slice}")
+
+            # Update affine for this slice
+            slice_affine = create_slice_affine(temp_seg_affine, slice_idx, FSAVERAGE_MIDDLE)
+
+            # Process this slice
+            result, contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx = process_slice(segmentation, slice_idx, ac_coords, pc_coords, slice_affine,
+                                                                                num_thickness_points, subdivisions, subdivision_method, contour_smoothing)
+
+            # insert the contour into the mesh container
+            cc_mesh.add_contour(slice_idx, contour_with_thickness[0], contour_with_thickness[1], start_end_idx=(anterior_endpoint_idx, posterior_endpoint_idx))
+
+            if result is not None:
+                slice_results.append(result)
+
+                # For single slice mode, save to main directory
+                if slice_selection != "all":
+                    output_subdir = output_dir
+                else:
+                    # For all slices mode, create per-slice directory
+                    output_subdir = output_dir / f'slice_{slice_idx}'
+                    output_subdir.mkdir(exist_ok=True)
+
+                # Create visualization for this slice
+                IO_processes.append(create_visualization(subdivision_method, result, midslices[slice_idx:slice_idx+1],
+                                    output_subdir, ac_coords, pc_coords, vox_size, f' (Slice {slice_idx})'))
+
+    if save_template is not None:
+        # Convert to Path object and ensure directory exists
+        template_dir = Path(save_template)
+        template_dir.mkdir(parents=True, exist_ok=True)
+        cc_mesh.save_contours(str(template_dir / 'contours.txt'))
+        cc_mesh.save_thickness_values(str(template_dir / 'thickness_values.txt'))
+        cc_mesh.save_thickness_measurement_points(str(template_dir / 'thickness_measurement_points.txt'))
+
+    if len(cc_mesh.contours) > 1:
+        cc_mesh.fill_thickness_values()
+        cc_mesh.create_mesh()
+        cc_mesh.smooth_(1)
+        cc_mesh.plot_mesh()
+        cc_mesh.write_vtk(str(output_dir / 'cc_mesh.vtk'))
+        cc_mesh.snap_cc_picture(str(output_dir / 'cc_mesh_snap.png'))
+
+    if not slice_results:
+        raise RuntimeError("No valid slices were found for postprocessing")
+
+    return slice_results, IO_processes
\ No newline at end of file
diff --git a/CorpusCallosum/shape/cc_subsegment_contour.py b/CorpusCallosum/shape/cc_subsegment_contour.py
new file mode 100644
index 00000000..4a754793
--- /dev/null
+++ b/CorpusCallosum/shape/cc_subsegment_contour.py
@@ -0,0 +1,988 @@
+import nibabel as nib
+import numpy as np
+from scipy.spatial import ConvexHull
+from tqdm import tqdm
+import pandas as pd
+
+from shape.cc_thickness import cc_thickness, convert_to_ras
+from shape.cc_endpoint_heuristic import get_endpoints
+
+
+def minimum_bounding_rectangle(points):
+    """
+    Find the smallest bounding rectangle for a set of points.
+    Returns the corners of the bounding box.
+
+    :param points: a 2xN matrix of coordinates (transposed internally)
+    :return: a 4x2 matrix with the corner coordinates of the rectangle
+    """
+    pi2 = np.pi/2.
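+    # Rotating-calipers idea: the minimum-area rectangle must share an edge
+    # direction with the convex hull, so it suffices to test one rotation per
+    # unique hull-edge angle (folded into [0, pi/2)) and keep the smallest box.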
+    points = points.T
+
+    # get the convex hull for the points
+    hull_points = points[ConvexHull(points).vertices]
+
+    # calculate edge angles
+    edges = hull_points[1:] - hull_points[:-1]
+    angles = np.arctan2(edges[:, 1], edges[:, 0])
+    angles = np.abs(np.mod(angles, pi2))
+    angles = np.unique(angles)
+
+    # find rotation matrices
+    rotations = np.vstack([
+        np.cos(angles),
+        np.cos(angles-pi2),
+        np.cos(angles+pi2),
+        np.cos(angles)]).T
+    rotations = rotations.reshape((-1, 2, 2))
+
+    # apply rotations to the hull
+    rot_points = np.dot(rotations, hull_points.T)
+
+    # find the bounding points
+    min_x = np.nanmin(rot_points[:, 0], axis=1)
+    max_x = np.nanmax(rot_points[:, 0], axis=1)
+    min_y = np.nanmin(rot_points[:, 1], axis=1)
+    max_y = np.nanmax(rot_points[:, 1], axis=1)
+
+    # find the box with the best area
+    areas = (max_x - min_x) * (max_y - min_y)
+    best_idx = np.argmin(areas)
+
+    # return the best box
+    x1 = max_x[best_idx]
+    x2 = min_x[best_idx]
+    y1 = max_y[best_idx]
+    y2 = min_y[best_idx]
+    r = rotations[best_idx]
+
+    rval = np.zeros((4, 2))
+    rval[0] = np.dot([x1, y2], r)
+    rval[1] = np.dot([x2, y2], r)
+    rval[2] = np.dot([x2, y1], r)
+    rval[3] = np.dot([x1, y1], r)
+
+    return rval
+
+
+def get_area_from_subsegments(split_contours):
+    # integrate each nested split contour (trapezoidal rule, equivalent to the
+    # shoelace area for a closed polygon) and subtract consecutive areas to
+    # obtain the area of each individual subsegment
+    areas = [np.abs(np.trapz(split_contour[1], split_contour[0])) for split_contour in split_contours]
+    area_out = np.zeros(len(areas))
+    for i in range(len(areas)):
+        if i == len(areas)-1:
+            area_out[i] = areas[i]
+        else:
+            area_out[i] = areas[i] - areas[i+1]
+    return area_out
+
+
+def subsegment_midline_orthogonal(midline, area_weights, contour, plot=True, ax=None, extremes=None):
+    """Subdivide the contour with cuts orthogonal to the midline.
+
+    The cut positions are given as cumulative fractions of the midline length
+    (area_weights). Returns a tuple of (areas, split_contours).
+    """
+    # get vertex closest to midline end
+    midline_end_idx = np.argmin(np.linalg.norm(contour.T - midline[-1], axis=1))
+    # roll contour start to midline end
+    contour = np.roll(contour, -midline_end_idx, axis=1)
+
+    # interpolate split points at the requested fractions of the midline
+    edge_idx, edge_frac = np.divmod(len(midline) * np.array(area_weights), 1)
+    edge_idx = edge_idx.astype(int)
+    split_points = midline[edge_idx] + (midline[edge_idx+1] - midline[edge_idx]) * edge_frac[:,None]
+
+    # get edge for each split point
+    edge_directions = midline[edge_idx] - midline[edge_idx+1]
+    # get vector perpendicular to each midline edge
+    edge_ortho_vectors = np.column_stack((-edge_directions[:,1], edge_directions[:,0]))
+    edge_ortho_vectors = edge_ortho_vectors / np.linalg.norm(edge_ortho_vectors, axis=1)[:,None]
+
+    split_contours = []
+    split_contours.append(contour)
+
+    for pt_idx,split_point in enumerate(split_points):
+        intersections = []
+        for i in range(contour.shape[1] - 1):
+            # get contour segment
+            segment_start = contour[:, i]
+            segment_end = contour[:, i + 1]
+            segment_vector = segment_end - segment_start
+
+            # Check for intersection with the perpendicular line
+            matrix = np.array([segment_vector, -edge_ortho_vectors[pt_idx]]).T
+            if np.linalg.matrix_rank(matrix) < 2:
+                continue  # Skip parallel lines
+
+            # Solve for intersection
+            t, s = np.linalg.solve(matrix, split_point - segment_start)
+            if 0 <= t <= 1:
+                intersection_point = segment_start + t * segment_vector
+                intersections.append((i, intersection_point))
+
+        # get the two intersections closest to split_point
+        intersections.sort(key=lambda x: np.linalg.norm(x[1] - split_point))
+
+        # Create new contours by splitting at intersections
+        if intersections:
+            first_index, first_intersection = intersections[1]
+            second_index, second_intersection = intersections[0]
+
+            if first_index > second_index:
+                first_index, second_index = second_index, first_index
+                first_intersection, second_intersection = second_intersection, first_intersection
+
+            first_index += 1
+
+            # connect first and second half
+            start_to_cutoff = np.hstack((contour[:, :first_index], first_intersection[:, None], second_intersection[:, None], contour[:, second_index + 1:]))
+            split_contours.append(start_to_cutoff)
+        else:
+            raise ValueError('No intersections found, this should not happen')
+
+    if plot:
+        extremes = [midline[0], midline[-1]]
+
+        import seaborn as sns
+        import matplotlib.pyplot as plt
+        if ax is None:
+            SHOW = True
+            fig, ax = plt.subplots(1,1,figsize=(8, 6))
+            ax.axis('equal')
+        else:
+            SHOW = False
+        # pretty plot with the subsegment areas filled in the polygon
+        colors = sns.color_palette("ch:start=.2,rot=-.3", len(split_contours))
+        for color, split_contour in zip(colors, split_contours):
+            ax.fill(split_contour[0], split_contour[1], alpha=0.5, color=color)
+        # plot contour
+        ax.plot(contour[0], contour[1], '-', linewidth=2, color='grey')
+        # add endpoints to split_points
+        split_points = split_points.tolist()
+        split_points.insert(0, extremes[0])
+        split_points.append(extremes[1])
+        ax.plot(midline[:,0], midline[:,1], 'k--', linewidth=2)
+
+        # plot edge orthogonal to each split point
+        for i in range(0,len(edge_ortho_vectors)):
+            pt = split_points[i+1]
+            length = 0.4
+            ax.plot([pt[0]-edge_ortho_vectors[i][0]*length, pt[0]+edge_ortho_vectors[i][0]*length], [pt[1]-edge_ortho_vectors[i][1]*length, pt[1]+edge_ortho_vectors[i][1]*length], 'k-', linewidth=2)
+
+        # start point & end point
+        ax.plot(extremes[0][0], extremes[0][1], marker='o', markersize=8, color='black')
+        ax.plot(extremes[1][0], extremes[1][1], marker='o', markersize=8, color='black')
+
+        ax.set_title('Split Contours')
+
+        if SHOW:
+            ax.axis('off')
+            ax.invert_xaxis()
+            ax.axis('equal')
+            plt.show()
+
+    return get_area_from_subsegments(split_contours), split_contours
+
+
+def hampel_subdivide_contour(contour, num_rays, plot=False, ax=None):
+    """Subdivide the contour with equally spaced rays cast from the midpoint of
+    the lower long edge of the minimum bounding rectangle (angular scheme,
+    Hampel et al.). Returns a tuple of (areas, split_contours)."""
+
+    # Find the extreme points in the x-direction
+    min_x_index = np.argmin(contour[0])
+    contour = np.roll(contour, -min_x_index, axis=1)
+
+    # get minimal bounding box around contour
+    min_bounding_rectangle = minimum_bounding_rectangle(contour)
+
+    # get long edges of rectangle
+    rectangle_duplicate_last = np.vstack((min_bounding_rectangle, min_bounding_rectangle[0]))
+    long_edges = np.diff(rectangle_duplicate_last, axis=0)
+    long_edges = np.linalg.norm(long_edges, axis=1)
+    long_edges_idx = np.argpartition(long_edges, -2)[-2:]
+
+    # select lower long edge
+    min_val = np.inf
+    min_idx = None
+    for i in long_edges_idx:
+        if rectangle_duplicate_last[i][1] < min_val:
+            min_val = rectangle_duplicate_last[i][1]
+            min_idx = i
+
+        if rectangle_duplicate_last[i+1][1] < min_val:
+            min_val = rectangle_duplicate_last[i+1][1]
+            min_idx = i
+
+    lowest_points = rectangle_duplicate_last[[min_idx, min_idx+1]]
+
+    # sort lowest points by descending x coordinate
+    if lowest_points[0,0] < lowest_points[1,0]:
+        lowest_points = lowest_points[::-1]
+
+    # get midpoint of lower edge of rectangle
+    midpoint_lower_edge = np.mean(lowest_points, axis=0)
+
+    # get angle of lower edge of rectangle to x-axis
+    angle_lower_edge = np.arctan2(lowest_points[1, 1] - lowest_points[0, 1], lowest_points[1, 0] - lowest_points[0, 0])
+
+    # get angles for equally spaced rays
+    angles = np.linspace(-angle_lower_edge, -angle_lower_edge + np.pi, num_rays+2, endpoint=True)
+    angles = angles[1:-1]
+
+    # create ray vectors
+    ray_vectors = np.vstack((np.cos(angles), np.sin(angles)))
+    # make ray vectors unit length
+    ray_vectors = ray_vectors / np.linalg.norm(ray_vectors, axis=0)
+
+    # invert x of ray vectors
+    ray_vectors[0] = -ray_vectors[0]
+
+    # Subdivision logic
+    split_contours = []
+    for ray_vector in ray_vectors.T:
+        intersections = []
+        for i in range(contour.shape[1] - 1):
+            segment_start = contour[:, i]
+            segment_end = contour[:, i + 1]
+            segment_vector = segment_end - segment_start
+
+            # Check for intersection with the ray
+            matrix = np.array([segment_vector, -ray_vector]).T
+            if np.linalg.matrix_rank(matrix) < 2:
+                continue  # Skip parallel lines
+
+            # Solve for intersection
+            t, s = np.linalg.solve(matrix, midpoint_lower_edge - segment_start)
+            if 0 <= t <= 1:
+                intersection_point = segment_start + t * segment_vector
+                intersections.append((i, intersection_point))
+
+        # Sort intersections by their position along the contour
+        intersections.sort(key=lambda x: x[0])
+
+        # Create new contours by splitting at intersections
+        if intersections:
+            first_index, first_intersection = intersections[0]
+            second_index, second_intersection = intersections[-1]
+
+            start_to_cutoff = np.hstack((contour[:, :first_index], first_intersection[:, None], second_intersection[:, None], contour[:, second_index + 1:]))
+
+            # connect first and second half
+            split_contours.append(start_to_cutoff)
+        else:
+            raise ValueError('No intersections found, this should not happen')
+
+    split_contours.append(contour)
+    split_contours = split_contours[::-1]
+
+    # Plotting logic
+    if plot:
+        import seaborn as sns
+        import matplotlib.pyplot as plt
+        if ax is None:
+            fig, ax = plt.subplots(1, 1, figsize=(8, 6))
+            ax.axis('equal')
+            SHOW = True
+        else:
+            SHOW = False
+        min_bounding_rectangle_plot = np.vstack((min_bounding_rectangle, min_bounding_rectangle[0]))
+        ax.plot(min_bounding_rectangle_plot[:,0], min_bounding_rectangle_plot[:,1], 'k--')
+        ax.plot(midpoint_lower_edge[0], midpoint_lower_edge[1], 'ko', markersize=8)
+        for i, ray_vector in enumerate(ray_vectors.T):
+            ray_length = 15
+            ray_vector *= -ray_length
+            ax.plot([midpoint_lower_edge[0], midpoint_lower_edge[0] + ray_vector[0]],
+                    [midpoint_lower_edge[1], midpoint_lower_edge[1] + ray_vector[1]],
+                    'k--')
+        # pretty plot with the subsegment areas filled in the polygon
+        colors = sns.color_palette("ch:start=.2,rot=-.3", len(split_contours))
+        for color, split_contour in zip(colors, split_contours):
+            ax.fill(split_contour[0], split_contour[1], alpha=0.5, color=color)
+        ax.plot(contour[0], contour[1], '-', linewidth=2, color='grey')
+
+        ax.set_title('Split Contours')
+        ax.axis('off')
+        if SHOW:
+            ax.axis('equal')
+            plt.show()
+
+    return get_area_from_subsegments(split_contours), split_contours
+
+
+def subdivide_contour(contour, area_weights, plot=False, ax=None, plot_transform=None, oriented=False, hline_anchor=None):
+    """Subdivide the contour with straight cuts perpendicular to the line between
+    the anterior and posterior extreme points, placed at the cumulative
+    fractions given in area_weights. Returns a tuple of (areas, split_contours)."""
+
+    # roll the contour so that it starts at the point with the largest x coordinate
+    min_x_index = np.argmax(contour[0])
+    contour = np.roll(contour, -min_x_index, axis=1)
+
+    min_x_index = 0
+    max_x_index = np.argmin(contour[0])
+
+    if oriented:
+        contour_x_sorted = np.sort(contour[0])
+        min_x = contour_x_sorted[0]
+        max_x = contour_x_sorted[-1]
+        extremes = (np.array([max_x, 0]), np.array([min_x, 0]))
+
+        if hline_anchor is not None:
+            extremes = (np.array([max_x, hline_anchor[1]]), np.array([min_x, hline_anchor[1]]))
+    else:
+        extremes = (contour[:, min_x_index].copy(), contour[:, max_x_index].copy())
+        # Calculate the line between the extreme points
+        start_point, end_point = extremes
+        line_vector = end_point - start_point
+        line_length = np.linalg.norm(line_vector)
+
+        # Normalize the line vector
+        line_unit_vector = line_vector / line_length
+
+        # Calculate the perpendicular vector
+        perp_vector = np.array([-line_unit_vector[1], line_unit_vector[0]])
+        perp_vector = perp_vector / np.linalg.norm(perp_vector)
+
+        if hline_anchor is None:
+            most_inferior_point = np.min(contour[1])
+            # move extreme 1 down below the inferior point and extreme 2 the same distance (so the angle stays the same)
+            down_distance = (extremes[1][1] - most_inferior_point) * 1.3
+            start_point = extremes[0] + down_distance * perp_vector
+            end_point = extremes[1] + down_distance * perp_vector
+        else:
+            # get closest point on line to hline_anchor
+            intersection = start_point + line_unit_vector * np.dot(hline_anchor - start_point, line_unit_vector)
+            # get distance of the closest point on the line to hline_anchor
+            distance = np.linalg.norm(intersection - hline_anchor)
+
+            # move start and end point the same distance
+            start_point = extremes[0] + distance * perp_vector
+            end_point = extremes[1] + distance * perp_vector
+
+        extremes = (start_point, end_point)
+
+    # Calculate the line between the extreme points
+    start_point, end_point = extremes
+    line_vector = end_point - start_point
+    line_length = np.linalg.norm(line_vector)
+
+    # Normalize the line vector
+    line_unit_vector = line_vector / line_length
+
+    # Calculate the perpendicular vector
+    perp_vector = np.array([-line_unit_vector[1], line_unit_vector[0]])
+
+    # Calculate split points based on area weights
+    split_points = []
+    for i,weight in enumerate(area_weights):
+        split_distance = weight * line_length
+        split_point = start_point + split_distance * line_unit_vector
+        split_points.append(split_point)
+
+    # Split the contour at the calculated split points
+    split_contours = []
+    split_contours.append(contour)
+    for split_point in split_points:
+        intersections = []
+        for i in range(contour.shape[1] - 1):
+            segment_start = contour[:, i]
+            segment_end = contour[:, i + 1]
+            segment_vector = segment_end - segment_start
+
+            # Check for intersection with the perpendicular line
+            matrix = np.array([segment_vector, -perp_vector]).T
+            if np.linalg.matrix_rank(matrix) < 2:
+                continue  # Skip parallel lines
+
+            # Solve for intersection
+            t, s = np.linalg.solve(matrix, split_point - segment_start)
+            if 0 <= t <= 1:
+                intersection_point = segment_start + t * segment_vector
+                intersections.append((i, intersection_point))
+
+        # get the two intersections that have the highest y coordinate
+        intersections.sort(key=lambda x: x[1][1], reverse=True)
+
+        # Create new contours by splitting at intersections
+        if intersections:
+            first_index, first_intersection = intersections[1]
+            second_index, second_intersection = intersections[0]
+
+            if first_index > second_index:
+                first_index, second_index = second_index, first_index
+                first_intersection, second_intersection = second_intersection, first_intersection
+
+            first_index += 1
+
+            start_to_cutoff = np.hstack((first_intersection[:, None], contour[:, first_index:second_index], second_intersection[:, None]))
+
+            # connect first and second half
+            split_contours.append(start_to_cutoff)
+        else:
+            raise ValueError('No intersections found, this should not happen')
+
+    if plot:
+        # make vline at every split point
+        split_points_vlines_start = (np.array(split_points) - perp_vector * 1).T
+        split_points_vlines_end = (np.array(split_points) + perp_vector * 1).T
+
+        if oriented:
+            # make another vline at start point and end point, this time perpendicular to the x-axis instead of the line
+            start_point_vline = np.array([start_point, np.array([start_point[0], start_point[1]+8])])
+            end_point_vline = np.array([end_point, np.array([end_point[0], end_point[1]+8])])
+        else:
+            start_point_vline = np.array([start_point, start_point - perp_vector * 8])
+            end_point_vline = np.array([end_point, end_point - perp_vector * 8])
+
+        if plot_transform is not None:
+            split_contours = [plot_transform(split_contour) for split_contour in split_contours]
+            contour = plot_transform(contour)
+            extremes = [plot_transform(extreme[:,None]) for extreme in extremes]
+            split_points = [plot_transform(split_point[:,None]) for split_point in split_points]
+            split_points_vlines_start = plot_transform(split_points_vlines_start)
+            split_points_vlines_end = plot_transform(split_points_vlines_end)
+            start_point_vline = plot_transform(start_point_vline.T).T
+            end_point_vline = plot_transform(end_point_vline.T).T
+
+        import seaborn as sns
+        import matplotlib.pyplot as plt
+        if ax is None:
+            SHOW = True
+            fig, ax = plt.subplots(1,1,figsize=(8, 6))
+            ax.axis('equal')
+        else:
+            SHOW = False
+        # pretty plot with the subsegment areas filled in the polygon
+        colors = sns.color_palette("ch:start=.2,rot=-.3", len(split_contours))
+        for color, split_contour in zip(colors, split_contours):
+            ax.fill(split_contour[0], split_contour[1], alpha=0.5, color=color)
+        # plot contour
+        ax.plot(contour[0], contour[1], '-', linewidth=2, color='grey')
+        # dashed line between start point & end point
+        ax.plot(np.vstack((extremes[0][0], extremes[1][0])), np.vstack((extremes[0][1], extremes[1][1])), '--', linewidth=2, color='grey')
+        # markers at every split point
+        for i in range(split_points_vlines_start.shape[1]):
+            ax.plot(np.vstack((split_points_vlines_start[:,i][0], split_points_vlines_end[:,i][0])),
+                    np.vstack((split_points_vlines_start[:,i][1], split_points_vlines_end[:,i][1])), 'k-', linewidth=2)
+
+        ax.plot(start_point_vline[:,0], start_point_vline[:,1], '--', linewidth=2, color='grey')
+        ax.plot(end_point_vline[:,0], end_point_vline[:,1], '--', linewidth=2, color='grey')
+        # put text between split points
+        # add endpoints to split_points
+        split_points.insert(0, extremes[0])
+        split_points.append(extremes[1])
+        # convert area_weights into per-segment fractions of the total line length
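+        # worked example (using the Hofer-Frahm weights from elsewhere in this module):
+        #   cumulative [1/6, 1/2, 2/3, 3/4] -> per-segment [1/6, 1/3, 1/6, 1/12, 1/4]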
+        # cumulative difference
+        area_weights_diff = []
+        area_weights_diff.append(area_weights[0])
+        for i in range(1,len(area_weights)):
+            area_weights_diff.append(area_weights[i] - area_weights[i-1])
+        area_weights_diff.append(1 - area_weights[-1])
+
+        from fractions import Fraction
+        area_weights_txt = [Fraction(area_weights_diff[i]).limit_denominator(1000) for i in range(len(area_weights_diff))]
+
+        for i in range(len(split_points)-1):
+            midpoint = np.mean([split_points[i], split_points[i+1]], axis=0)
+            ax.text(midpoint[0], midpoint[1]-5, f'{area_weights_txt[i]}', color='black', fontsize=11, horizontalalignment='center')
+
+        # start point & end point
+        ax.plot(extremes[0][0], extremes[0][1], marker='o', markersize=8, color='black')
+        ax.plot(extremes[1][0], extremes[1][1], marker='o', markersize=8, color='black')
+
+        ax.set_title('Split Contours')
+
+        ax.axis('off')
+        if SHOW:
+            ax.axis('equal')
+            plt.show()
+
+    return get_area_from_subsegments(split_contours), split_contours
+
+
+def transform_to_acpc_standard(contour_ras, ac_pt_ras, pc_pt_ras):
+    # translate AC to the origin and rotate PC onto (-ac_pc_dist, 0)
+    translation_matrix = np.array([[1, 0, -ac_pt_ras[0]],
+                                   [0, 1, -ac_pt_ras[1]],
+                                   [0, 0, 1]])
+
+    ac_pc_vec = pc_pt_ras - ac_pt_ras
+    ac_pc_dist = np.linalg.norm(ac_pc_vec)
+
+    posterior_vector = np.array([-ac_pc_dist, 0])
+
+    # get angle between ac_pc_vec and posterior_vector
+    dot_product = np.dot(ac_pc_vec, posterior_vector)
+    norms_product = np.linalg.norm(ac_pc_vec) * np.linalg.norm(posterior_vector)
+    theta = np.arccos(dot_product / norms_product)
+
+    # Determine the sign of the angle using cross product
+    cross_product = np.cross(ac_pc_vec, posterior_vector)
+    if cross_product < 0:
+        theta = -theta
+
+    # create rotation matrix for theta
+    rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
+                                [np.sin(theta), np.cos(theta), 0],
+                                [0, 0, 1]])
+
+    # apply translation and rotation
+    if contour_ras.shape[0] == 2:
+        contour_ras_homogeneous = np.vstack([contour_ras, np.ones(contour_ras.shape[1])])
+    else:
+        contour_ras_homogeneous = contour_ras
+
+    contour_acpc = (rotation_matrix @ translation_matrix) @ contour_ras_homogeneous
+    contour_acpc = contour_acpc[:2, :]
+
+    rotate_back = lambda x: (np.linalg.inv(rotation_matrix @ translation_matrix) @ np.vstack([x, np.ones(x.shape[1])]))[:2, :]
+    return contour_acpc, np.array([0, 0]), np.array([-ac_pc_dist, 0]), rotate_back
+
+
+def preprocess_cc(cc_label_nib, paths_csv, subj_id):
+    cc_mask = cc_label_nib.get_fdata() == 192
+    cc_mask = cc_mask[cc_mask.shape[0]//2]
+
+    posterior_commissure_center = paths_csv.loc[subj_id, 'PC_center_r':'PC_center_s'].to_numpy().astype(float)
+    anterior_commissure_center = paths_csv.loc[subj_id, 'AC_center_r':'AC_center_s'].to_numpy().astype(float)
+
+    # adjust LR from label coordinates to orig_up coordinates
+    posterior_commissure_center[0] = 128
+    anterior_commissure_center[0] = 128
+
+    # orientation I, A
+    # rotate image so anterior and posterior commissures are horizontal
+    AC_2d = anterior_commissure_center[1:]
+    PC_2d = posterior_commissure_center[1:]
+
+    return cc_mask, AC_2d, PC_2d
+
+
+def get_primary_eigenvector(contour_ras):
+    # Center the data by subtracting mean
+    contour_mean = np.mean(contour_ras, axis=1, keepdims=True)
+    contour_centered = contour_ras - contour_mean
+
+    # Calculate covariance matrix
+    cov_matrix = np.cov(contour_centered)
+
+    # Get eigenvalues and eigenvectors (PCA)
+    eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
+
+    # Sort in descending order
+    idx = eigenvalues.argsort()[::-1]
+    eigenvalues = eigenvalues[idx]
+    eigenvectors = eigenvectors[:,idx]
+
+    # make the primary eigenvector unit length
+    primary_eigenvector = eigenvectors[:,0] / np.linalg.norm(eigenvectors[:,0])
+    pt0 = np.mean(contour_ras, axis=1)
+    pt0 -= np.array([0, 5])
+    pt1 = pt0 + primary_eigenvector * 100
+
+    return pt0, pt1
+
+
+if __name__ == "__main__":
+
+    OUTPUT_TO_RAS = False
+    PLOT = False
+
+    paths_csv = pd.read_csv('/groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/network/data/found_labels_with_meta_data_difficult_final.csv', index_col=0)
+
+    for subj_id in tqdm(paths_csv.index):
+        label_path = paths_csv.loc[subj_id, 'label_merged']
+
+        try:
+            cc_label_nib = nib.load(label_path)
+        except Exception as e:
+            print(subj_id, 'error', e)
+            continue
+
+        PC_2d = paths_csv.loc[subj_id, 'PC_center_r':'PC_center_s'].to_numpy().astype(float)[1:]
+        AC_2d = paths_csv.loc[subj_id, 'AC_center_r':'AC_center_s'].to_numpy().astype(float)[1:]
+
+        cc_mask = cc_label_nib.get_fdata() == 192
+        cc_mask = cc_mask[cc_mask.shape[0]//2]
+
+        contour, anterior_endpoint_idx, posterior_endpoint_idx = get_endpoints(cc_mask, AC_2d, PC_2d, cc_label_nib.header.get_zooms()[1], return_coordinates=False)
+        contour_as = convert_to_ras(contour, cc_label_nib.affine)
+        ac_pt_as = convert_to_ras(AC_2d, cc_label_nib.affine)
+        pc_pt_as = convert_to_ras(PC_2d, cc_label_nib.affine)
+
+        np.save(f'./contours/{subj_id}_contour_as.npy', contour_as)
+        np.save(f'./contours/{subj_id}_ac_pt_as.npy', ac_pt_as)
+        np.save(f'./contours/{subj_id}_pc_pt_as.npy', pc_pt_as)
+        continue
+
+        #### contour to ACPC standard ####
+        contour_acpc, ac_pt_acpc, pc_pt_acpc, rotate_back_acpc = transform_to_acpc_standard(contour_as, ac_pt_as, pc_pt_as)
+
+        import matplotlib.pyplot as plt
+
+        print(subj_id)
+
+        fig, ax = plt.subplots(2,3,figsize=(12, 8), sharex=True, sharey=True)
+
+        # Aboitiz scheme
+        subdivided_contour = subdivide_contour(contour_as, area_weights=[1/3, 2/3, 4/5], plot=PLOT, ax=ax[0,0], oriented=False, hline_anchor=ac_pt_as)
+        ax[0,0].set_title('Aboitiz')
+
+        # Witelson scheme
+        subdivided_contour = subdivide_contour(contour_as, area_weights=[1/3, 1/2, 2/3, 4/5], plot=PLOT, ax=ax[0,1], oriented=False, hline_anchor=ac_pt_as)
+        ax[0,1].set_title('Witelson')
+
+        # Jäncke scheme
+        subdivided_contour = subdivide_contour(contour_acpc, area_weights=[1/3, 1/2, 2/3, 4/5], plot=PLOT, ax=ax[0,2], oriented=True, plot_transform=rotate_back_acpc, hline_anchor=ac_pt_acpc)
+        ax[0,2].set_title('Jäncke')
+
+        # Hofer-Frahm scheme
+        subdivided_contour = subdivide_contour(contour_as, area_weights=[1/6, 1/2, 2/3, 3/4], plot=PLOT, ax=ax[1,0], oriented=False, hline_anchor=ac_pt_as)
+        ax[1,0].set_title('Hofer-Frahm')
+
+        # Hampel scheme
+        subdivided_contour = hampel_subdivide_contour(contour_as, num_rays=4, plot=PLOT, ax=ax[1,1])
+        ax[1,1].set_title('Hampel')
+
+        # mri_cc-style subdivision along the primary eigenvector
+        pt0, pt1 = get_primary_eigenvector(contour_as)
+        contour_eigen, pt0_eigen, pt1_eigen, rotate_back_eigen = transform_to_acpc_standard(contour_as, pt0, pt1)
+        ac_pt_eigen, _, _, _ = transform_to_acpc_standard(ac_pt_as[:, None], pt0, pt1)
+        ac_pt_eigen = ac_pt_eigen[:, 0]
+        subdivided_contour = subdivide_contour(contour_eigen, area_weights=[1/5, 2/5, 3/5, 4/5], plot=PLOT, ax=ax[1,2], oriented=True, plot_transform=rotate_back_eigen, hline_anchor=ac_pt_eigen)
+        ax[1,2].set_title('mri_cc')
+
+        try:
+            midline_length, thickness, curvature, midline_equidistant, levelpaths, *_ = cc_thickness(contour_as.T, anterior_endpoint_idx, posterior_endpoint_idx)
+        except Exception as e:
+            # jitter slightly to break degenerate (duplicate) contour points
+            contour_as += np.random.randn(contour_as.shape[0], contour_as.shape[1])*0.0001
+            midline_length, thickness, curvature, midline_equidistant, levelpaths, *_ = cc_thickness(contour_as.T, anterior_endpoint_idx, posterior_endpoint_idx)
+            print('Successfully computed thickness after adding noise')
+
+        plt.tight_layout()
+        # make axis equal
+        for a in ax.flatten():
+            a.set_aspect('equal', adjustable='box')
+            a.axis('off')
+
+        ax[0,0].invert_xaxis()
+
+        plt.savefig(f'/groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/cc_pipeline/subdivision_plots/cc_subdivisions_{subj_id}.png', dpi=300, bbox_inches='tight')
+        plt.show()
+        plt.close()
+
+        fig, ax = plt.subplots(1,1,figsize=(5, 4))
+        areas, split_contours = subsegment_midline_orthogonal(midline_equidistant, [1/6, 1/2, 2/3, 3/4], contour_as, plot=True, ax=ax)
+        ax.invert_xaxis()
+        ax.set_title('Midline subdivision - Hofer-Frahm ratios')
+        ax.axis('equal')
+        ax.axis('off')
+        plt.savefig(f'/groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/cc_pipeline/subdivision_plots/cc_subdivisions_{subj_id}_midline.png', dpi=300, bbox_inches='tight')
+        plt.show()
+        plt.close()
\ No newline at end of file
diff --git a/CorpusCallosum/shape/cc_thickness.py b/CorpusCallosum/shape/cc_thickness.py
new file mode 100644
index 00000000..24df8374
--- /dev/null
+++ b/CorpusCallosum/shape/cc_thickness.py
@@ -0,0 +1,368 @@
+import sys
+import os
+
+import numpy as np
+from lapy import TriaMesh, Solver
+from lapy.diffgeo import compute_rotated_f
+import meshpy.triangle as triangle
+import scipy.interpolate
+
+
+class HiddenPrints:
+    def __enter__(self):
+        self._original_stdout = sys.stdout
+        sys.stdout = open(os.devnull, 'w')
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        sys.stdout.close()
+        sys.stdout = self._original_stdout
+
+
+def compute_curvature(path):
+    # compute curvature by computing edge angles
+    edges = np.diff(path, axis=0)
+    angles = np.arctan2(edges[:,1], edges[:,0])
+    # compute angle differences between consecutive edges
+    angle_diffs = np.diff(angles)
+    # wrap angles to [-pi, pi]
+    angle_diffs = np.mod(angle_diffs + np.pi, 2*np.pi) - np.pi
+    return angle_diffs
+
+
+def convert_to_ras(contour, vox2ras_matrix, get_parameters=False):
+    # converting to AS (no left-right dimension); out-of-plane movement is ignored, so we only do scaling, axis swapping and flipping - no rotation
+    # translation is ignored
+    if contour.shape[0] == 2:
+        # get only axis swaps
+        axis_swaps = np.round(vox2ras_matrix[:3,:3], 0)
+        permutation = np.argwhere(axis_swaps != 0)[:,1]
+        assert(len(permutation) == 3)
+
+        idx_superior = np.argwhere(permutation == 2)
+        idx_anterior = np.argwhere(permutation == 1)
+
+        swap_axes = idx_anterior > idx_superior
+        if swap_axes:
+            # swap anterior and superior
+            contour = contour[[1,0]]
+
+        # determine if axes were reversed
+        superior_reversed = (axis_swaps[2,:] == -1).any()
+        anterior_reversed = (axis_swaps[1,:] == -1).any()
+
+        # flip axes if necessary
+        if superior_reversed:
+            contour[1] = -contour[1]
+        if anterior_reversed:
+            contour[0] = -contour[0]
+
+        # get scaling by getting length of three column vectors
+        scaling = np.linalg.norm(vox2ras_matrix[:3,:3], axis=0)
+
+        # apply transformation
+        contour = (contour.T / scaling[1:]).T
+
+        if get_parameters:
+            return contour, anterior_reversed, superior_reversed, swap_axes
+        else:
+            return contour
+
+    elif contour.shape[0] == 3:
+        # add a homogeneous coordinate and apply the full transformation
+        contour_homogeneous = np.vstack([contour, np.ones(contour.shape[1])])
+
+        # Apply the transformation
+        contour = (vox2ras_matrix @ contour_homogeneous)[:3, :]
+        return contour
+
+
+def set_contour_zero_idx(contour, idx, anterior_endpoint_idx, posterior_endpoint_idx):
+    # roll the contour so that index idx becomes index 0 and shift the endpoint indices accordingly
+    contour = np.roll(contour, -idx, axis=0)
+    anterior_endpoint_idx = (anterior_endpoint_idx - idx) % contour.shape[0]
+    posterior_endpoint_idx = (posterior_endpoint_idx - idx) % contour.shape[0]
+    return contour, anterior_endpoint_idx, posterior_endpoint_idx
+
+
+def find_closest_edge(point, contour):
+    """Find the index of the edge closest to the given point.
+
+    Args:
+        point: 2D point coordinates
+        contour: Array of contour points (N x 2)
+
+    Returns:
+        Index of the closest edge
+    """
+    edges_start = contour[:-1, :2]  # N-1 x 2
+    edges_end = contour[1:, :2]  # N-1 x 2
+    edges_vec = edges_end - edges_start  # N-1 x 2
+
+    # Calculate projection coefficient for all edges at once
+    # (p-a)·(b-a) / |b-a|²
+    edge_lengths_sq = np.sum(edges_vec * edges_vec, axis=1)
+    # Avoid division by zero for degenerate edges
+    valid_edges = edge_lengths_sq > 1e-10
+    t = np.zeros(len(edges_start))
+    t[valid_edges] = np.sum((point - edges_start[valid_edges]) * edges_vec[valid_edges], axis=1) / edge_lengths_sq[valid_edges]
+    t = np.clip(t, 0, 1)  # Clamp to edge endpoints
+
+    # Get closest points on all edges
+    closest_points = edges_start + t[:,None] * edges_vec
+
+    # Calculate distances to all edges
+    distances = np.linalg.norm(point - closest_points, axis=1)
+
+    # Return index of closest edge
+    return np.argmin(distances)
+
+
+def insert_point_to_contour(contour_with_thickness, point, thickness_value, get_index=False):
+    """Insert a point and its thickness value into the contour.
+
+    Args:
+        contour_with_thickness: List containing [contour_points, thickness_values]
+        point: 2D point to insert
+        thickness_value: Thickness value corresponding to the point
+        get_index: If True, also return the index at which the point was inserted
+
+    Returns:
+        Updated contour_with_thickness, or a tuple of
+        (contour_with_thickness, inserted index) if get_index is True
+    """
+    # Find closest edge for the point
+    edge_idx = find_closest_edge(point, contour_with_thickness[0])
+
+    # Insert point between edge endpoints
+    contour_with_thickness[0] = np.insert(contour_with_thickness[0], edge_idx+1, point, axis=0)
+    contour_with_thickness[1] = np.insert(contour_with_thickness[1], edge_idx+1, thickness_value)
+
+    if get_index:
+        return contour_with_thickness, edge_idx+1
+    else:
+        return contour_with_thickness
+
+
+def make_mesh_from_contour(contour_2d, max_volume=0.5, min_angle=25, verbose=False):
+
+    # build the boundary facets as consecutive point indices (closed polygon)
+    facets = np.vstack((np.arange(len(contour_2d)), ((np.arange(len(contour_2d))+1) % len(contour_2d)))).T
+
+    # use meshpy to create mesh
+    info = triangle.MeshInfo()
+    info.set_points(contour_2d)
+    info.set_facets(facets)
+    # NOTE: triangle.build crashes if the contour has duplicate points!
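+    # One possible guard (not applied here, to keep behavior unchanged) would be
+    # dropping consecutive duplicate points before meshing, e.g.:
+    #   keep = np.r_[True, np.any(np.diff(contour_2d, axis=0) != 0, axis=1)]
+    #   contour_2d = contour_2d[keep]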
+    mesh = triangle.build(info, max_volume=max_volume, min_angle=min_angle, verbose=verbose)
+
+    mesh_points = np.array(mesh.points)
+    mesh_trias = np.array(mesh.elements)
+
+    return mesh_points, mesh_trias
+
+
+def cc_thickness(contour_2d, anterior_endpoint_idx, posterior_endpoint_idx, n_points=100):
+    """Compute midline, thickness and curvature measures for a CC contour.
+
+    Solves a Poisson problem on a triangulation of the contour interior,
+    extracts the zero level set as the midline, and measures thickness as the
+    lengths of level paths of the rotated gradient field.
+
+    Args:
+        contour_2d: Array of contour points (N x 2)
+        anterior_endpoint_idx: Contour index of the anterior endpoint
+        posterior_endpoint_idx: Contour index of the posterior endpoint
+        n_points: Number of thickness measurement points along the midline
+
+    Returns:
+        Tuple of (midline_length, mean thickness, curvature, midline_equidistant,
+        levelpaths, contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx)
+    """
+    # standardize contour indices, to get consistent levelpath directions
+    contour_2d, anterior_endpoint_idx, posterior_endpoint_idx = set_contour_zero_idx(contour_2d, anterior_endpoint_idx, anterior_endpoint_idx, posterior_endpoint_idx)
+
+    mesh_points, mesh_trias = make_mesh_from_contour(contour_2d)
+
+    # make points 3D by appending z=0
+    mesh_points3d = np.append(mesh_points, np.zeros((mesh_points.shape[0], 1)), axis=1)
+
+    # compute poisson
+    with HiddenPrints():
+        tria = TriaMesh(mesh_points3d, mesh_trias)
+    # extract boundary curve
+    bdr = np.array(tria.boundary_loops()[0])
+
+    # find index of endpoints in bdr list
+    iidx1 = np.where(bdr == anterior_endpoint_idx)[0][0]
+    iidx2 = np.where(bdr == posterior_endpoint_idx)[0][0]
+
+    # create boundary condition (0 at endpoints, -1 on one side, 1 on the other):
+    if iidx1 > iidx2:
+        iidx1, iidx2 = iidx2, iidx1
+    dcond = np.ones(bdr.shape)
+    dcond[iidx1] = 0
+    dcond[iidx2] = 0
+    dcond[iidx1+1:iidx2] = -1
+
+    # Extract path
+    with HiddenPrints():
+        fem = Solver(tria)
+        vfunc = fem.poisson(0, (bdr, dcond))
+    level = 0
+    midline_equidistant, midline_length = tria.level_path(vfunc, level, n_points=n_points+2)
+    midline_equidistant = midline_equidistant[:, :2]
+
+    # rotate the gradient field to obtain a function whose level sets run
+    # across the contour, perpendicular to the midline
+    with HiddenPrints():
+        gf = compute_rotated_f(tria, vfunc)
+
+    # interpolate midline to get levels to evaluate
+    gf_interp = scipy.interpolate.griddata(tria.v[:, 0:2], gf, midline_equidistant[:, 0:2], method='cubic')
+
+    levelpaths = []
+    levelpath_lengths = []
+    levelpath_tria_idx = []
+
+    contour_with_thickness = [contour_2d.copy(), np.full(contour_2d.shape[0], np.nan)]
+    for i in range(1, n_points+1):
+        level = gf_interp[i]
+        # levelpath starts at index zero
+        lvlpath, lvlpath_length, tria_idx = tria.level_path(gf, level, get_tria_idx=True)
+
+        levelpaths.append(lvlpath)
+        levelpath_lengths.append(lvlpath_length)
+        levelpath_tria_idx.append(tria_idx)
+
+        levelpath_start = lvlpath[0, :2]
+        levelpath_end = lvlpath[-1, :2]
+
+        contour_with_thickness, inserted_idx_start = insert_point_to_contour(contour_with_thickness, levelpath_start, lvlpath_length, get_index=True)
+        contour_with_thickness, inserted_idx_end = insert_point_to_contour(contour_with_thickness, levelpath_end, lvlpath_length, get_index=True)
+
+        # keep track of start and end indices
+        if inserted_idx_start <= anterior_endpoint_idx:
+            anterior_endpoint_idx += 1
+        if inserted_idx_end <= anterior_endpoint_idx:
+            anterior_endpoint_idx += 1
+
+        if inserted_idx_start >= posterior_endpoint_idx:
+            posterior_endpoint_idx += 1
+        if inserted_idx_end >= posterior_endpoint_idx:
+            posterior_endpoint_idx += 1
+
+    # get curvature of the equidistant midline
+    curvature = compute_curvature(midline_equidistant)
+    out_curvature = np.abs(np.degrees(np.mean(curvature))) / len(curvature)
+
+    return midline_length, np.mean(levelpath_lengths), out_curvature, midline_equidistant, levelpaths, contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx
+
diff --git a/CorpusCallosum/shape/resample_poly.py b/CorpusCallosum/shape/resample_poly.py
new file mode 100644
index 00000000..e9236dee
--- /dev/null
+++ b/CorpusCallosum/shape/resample_poly.py
@@ -0,0 +1,65 @@
+import numpy as np
+
+def resample_polygon(xy: np.ndarray, n_points: int = 100) -> np.ndarray:
+    # Cumulative Euclidean distance between successive polygon points.
+    # This will be the "x" for interpolation
+    d = np.cumsum(np.r_[0, np.sqrt((np.diff(xy, axis=0) ** 2).sum(axis=1))])
+
+    # get linearly spaced points along the cumulative Euclidean distance
+    d_sampled = np.linspace(0, d.max(), n_points)
+
+    # interpolate x and y coordinates
+    xy_interp = np.c_[
+        np.interp(d_sampled, d, xy[:, 0]),
+        np.interp(d_sampled, d, xy[:, 1]),
+    ]
+
+    return xy_interp
+
+def iterative_resample_polygon(xy: np.ndarray, n_points: int = 100, n_iter: int = 3) -> np.ndarray:
+    # resample multiple times to numerically stabilize the result to be truly equidistant
+    xy_resampled = resample_polygon(xy, n_points)
+    for _ in range(n_iter-1):
+        xy_resampled = resample_polygon(xy_resampled, n_points)
+    return xy_resampled
+
+
+if __name__ == "__main__":
+    import time
+    import matplotlib.pyplot as plt
+
+    coords = [
+        {'x': 354.0, 'y': 424.0}, {'x': 318.0, 'y': 455.0}, {'x': 299.0, 'y': 458.0}, {'x': 284.0, 'y': 464.0}, {'x': 250.0, 'y': 490.0},
+        {'x': 229.0, 'y': 492.0}, {'x': 204.0, 'y': 484.0}, {'x': 187.0, 'y': 469.0}, {'x': 176.0, 'y': 449.0}, {'x': 164.0, 'y': 435.0},
+        {'x': 119.0, 'y': 274.0}, {'x': 121.0, 'y': 264.0}, {'x': 118.0, 'y': 249.0}, {'x': 118.0, 'y': 224.0}, {'x': 121.0, 'y': 209.0},
+        {'x': 130.0, 'y': 194.0}, {'x': 138.0, 'y': 159.0}, {'x': 147.0, 'y': 139.0}, {'x': 155.0, 'y': 112.0}, {'x': 170.0, 'y': 89.0},
+        {'x': 190.0, 'y': 67.0}, {'x': 220.0, 'y': 54.0}, {'x': 280.0, 'y': 47.0}, {'x': 310.0, 'y': 55.0}, {'x': 330.0, 'y': 56.0},
+        {'x': 345.0, 'y': 60.0}, {'x': 355.0, 'y': 67.0}, {'x': 367.0, 'y': 80.0}, {'x': 375.0, 'y': 84.0}, {'x': 382.0, 'y': 95.0},
+    ]
+
+    # construct numpy array from list of dicts
+    xy = np.array([(c['x'], c['y']) for c in coords])
+
+    n_points = 30
+    # resample polygon
+    print(f"Resampling polygon with {len(xy)} points to {n_points} points")
+    start_time = time.time()
+    xy_resampled = iterative_resample_polygon(xy, n_points, n_iter=20)
+    end_time = time.time()
+    print(f"Time taken: {end_time - start_time:.2f} seconds")
+
+    # plot result
+    fig, ax = plt.subplots(figsize=(7,14))
+    ax.scatter(xy[:, 1], xy[:, 0], marker='o', s=150, label='original', color='black')
+    ax.scatter(xy_resampled[:, 1], xy_resampled[:, 0], label='resampled', color='red')
+    ax.set_aspect(1)
+    ax.invert_yaxis()
+    plt.legend()
+    plt.show()
+
+    # Calculate distances between consecutive vertices
+    distances = np.sqrt(np.sum((xy_resampled[1:] - xy_resampled[:-1])**2, axis=1))
+    print('Distance between consecutive vertices:', distances)
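Because `resample_polygon` interpolates along cumulative chord length, points near corners are only approximately equidistant after a single pass; `iterative_resample_polygon` repeats the resampling to tighten this. A quick self-check of the convergence (an illustrative sketch, assuming the module path added above):

```python
import numpy as np
from CorpusCallosum.shape.resample_poly import iterative_resample_polygon

# closed square traversed as a polyline
square = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [0.0, 0.0]])
pts = iterative_resample_polygon(square, n_points=17, n_iter=5)
# spacing between consecutive points; the relative spread shrinks with n_iter
seg = np.linalg.norm(np.diff(pts, axis=0), axis=1)
print(seg.std() / seg.mean())  # close to 0 -> nearly equidistant spacing
```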
diff --git a/CorpusCallosum/transforms/localization_transforms.py b/CorpusCallosum/transforms/localization_transforms.py
new file mode 100644
index 00000000..30891eb8
--- /dev/null
+++ b/CorpusCallosum/transforms/localization_transforms.py
@@ -0,0 +1,76 @@
+from monai.transforms import RandomizableTransform, MapTransform
+import numpy as np
+
+class CropAroundACPCFixedSize(RandomizableTransform, MapTransform):
+    """
+    Crop around AC and PC with fixed size
+    """
+
+    def __init__(self, keys, fixed_size: tuple[int, int], allow_missing_keys: bool = False, random_translate: float = 0) -> None:
+        MapTransform.__init__(self, keys, allow_missing_keys)
+        RandomizableTransform.__init__(self)
+        self.random_translate = random_translate
+        self.fixed_size = fixed_size
+
+    def __call__(self, data):
+        d = dict(data)
+
+        # Get AC and PC centers and the point halfway between them
+        pc_center = d['PC_center']
+        ac_center = d['AC_center']
+        center_point = ((ac_center + pc_center) / 2).astype(int)
+
+        # Half of the fixed output size in voxels
+        voxel_padding_x = self.fixed_size[0] // 2
+        voxel_padding_y = self.fixed_size[1] // 2
+
+        # Add random translation if specified
+        if self.random_translate > 0:
+            random_translate = np.random.randint(-int(self.random_translate), int(self.random_translate), size=2)
+        else:
+            random_translate = (0, 0)
+
+        # Calculate crop boundaries with padding and random translation
+        crop_left = center_point[1] - voxel_padding_x + random_translate[0]
+        crop_right = center_point[1] + voxel_padding_x + random_translate[0]
+        crop_top = center_point[2] - voxel_padding_y + random_translate[1]
+        crop_bottom = center_point[2] + voxel_padding_y + random_translate[1]
+
+        # raise error if crop boundaries are out of image
+        if crop_left < 0 or crop_right > d['image'].shape[2] or crop_top < 0 or crop_bottom > d['image'].shape[3]:
+            raise ValueError("Crop boundaries are out of image")
+
+        # Apply crop to all requested keys
+        for key in self.keys:
+            if key not in d.keys() and self.allow_missing_keys:
+                continue
+            d[key] = d[key][:, :, crop_left:crop_right, crop_top:crop_bottom]
+
+        # Update point coordinates relative to cropped image
+        d['PC_center'][1:] = d['PC_center'][1:] - [crop_left, crop_top]
+        d['AC_center'][1:] = d['AC_center'][1:] - [crop_left, crop_top]
+
+        d['crop_left'] = crop_left
+        d['crop_right'] = crop_right
+        d['crop_top'] = crop_top
+        d['crop_bottom'] = crop_bottom
+        return d
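For context, `CropAroundACPCFixedSize` is a MONAI dictionary transform. A minimal usage sketch (illustrative values; the transform assumes `(C, S, H, W)` arrays and `[slice, x, y]` voxel coordinates under the keys shown below):

```python
import numpy as np
from CorpusCallosum.transforms.localization_transforms import CropAroundACPCFixedSize

data = {
    "image": np.zeros((1, 3, 256, 256), dtype=np.float32),
    "AC_center": np.array([1.0, 128.0, 168.0]),
    "PC_center": np.array([1.0, 139.0, 143.0]),
}
crop = CropAroundACPCFixedSize(keys=["image"], fixed_size=(96, 96))
out = crop(data)
print(out["image"].shape)  # (1, 3, 96, 96), centered between AC and PC
print(out["AC_center"])    # coordinates shifted into the cropped frame
```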
diff --git a/CorpusCallosum/transforms/segmentation_transforms.py b/CorpusCallosum/transforms/segmentation_transforms.py
new file mode 100644
index 00000000..a296e942
--- /dev/null
+++ b/CorpusCallosum/transforms/segmentation_transforms.py
@@ -0,0 +1,105 @@
+from monai.transforms import RandomizableTransform, MapTransform
+import numpy as np
+
+
+class CropAroundACPC(RandomizableTransform, MapTransform):
+    """
+    Crop around AC and PC
+    """
+
+    def __init__(self, keys, allow_missing_keys: bool = False, padding_mm: float = 10, random_translate: float = 0) -> None:
+        MapTransform.__init__(self, keys, allow_missing_keys)
+        RandomizableTransform.__init__(self, prob=1, do_transform=True)
+        self.padding_mm = padding_mm
+        self.random_translate = random_translate
+
+    def __call__(self, data):
+        d = dict(data)
+
+        if 'AC_center_original' not in d:
+            d['AC_center_original'] = d['AC_center'].copy()
+        if 'PC_center_original' not in d:
+            d['PC_center_original'] = d['PC_center'].copy()
+
+        if self.random_translate > 0:
+            random_translate = np.random.randint(-int(self.random_translate), int(self.random_translate), size=2)
+        else:
+            random_translate = (0, 0)
+
+        for key in self.keys:
+            if key not in d.keys() and self.allow_missing_keys:
+                continue
+
+            # AC/PC centers are voxel coordinates of the form [slice, x, y],
+            # e.g. PC_center = [2., 139., 143.], AC_center = [2., 128., 168.]
+            pc_center = d['PC_center']
+            ac_center = d['AC_center']
+
+            ac_pc_bottomleft = (np.min([ac_center[1], pc_center[1]]).astype(int), np.min([ac_center[2], pc_center[2]]).astype(int))
+            ac_pc_topright = (np.max([ac_center[1], pc_center[1]]).astype(int), np.max([ac_center[2], pc_center[2]]).astype(int))
+
+            voxel_padding = round(self.padding_mm / d['res'])
+
+            crop_left = ac_pc_bottomleft[0] - int(voxel_padding*1.5) + random_translate[0]
+            crop_right = ac_pc_topright[0] + voxel_padding//2 + random_translate[0]
+            crop_top = ac_pc_bottomleft[1] - voxel_padding + random_translate[1]
+            crop_bottom = ac_pc_topright[1] + voxel_padding + random_translate[1]
+
+            # remember how much needs to be padded back to undo this crop
+            d['to_pad'] = (int(crop_left), int(d[key].shape[2] - crop_right),
+                           int(crop_top), int(d[key].shape[3] - crop_bottom))
+            d[key] = d[key][:, :, crop_left:crop_right, crop_top:crop_bottom]
+
+        return d
+
+
+class CropAroundACPCtrack(CropAroundACPC):
+    """
+    Same as crop around ACPC but also adjusts AC_center and PC_center accordingly
+    """
+
+    def __call__(self, data):
+        # First call parent class to get cropped image
+        d = super().__call__(data)
+
+        # Get the crop offsets that were used
+        pad_left, pad_right, pad_top, pad_bottom = d['to_pad']
+
+        # Adjust AC and PC center coordinates based on cropping
+        if 'AC_center' in d:
+            d['AC_center'][1] = d['AC_center_original'][1] - pad_left
+            d['AC_center'][2] = d['AC_center_original'][2] - pad_top
+
+        if 'PC_center' in d:
+            d['PC_center'][1] = d['PC_center_original'][1] - pad_left
+            d['PC_center'][2] = d['PC_center_original'][2] - pad_top
+
+        return d
+
+
+class UncropAroundACPC(MapTransform):
+    """
+    Uncrop around AC and PC - reverses CropAroundACPC transform by padding back to original size
+    """
+
+    def __init__(self, keys, allow_missing_keys: bool = False, padding_mm: float = 10) -> None:
+        super().__init__(keys, allow_missing_keys)
+        self.padding_mm = padding_mm
+
+    def __call__(self, data):
+        d = dict(data)
+        pad_left, pad_right, pad_top, pad_bottom = d['to_pad']
+
+        for key in self.keys:
+            if key not in d.keys() and self.allow_missing_keys:
+                continue
+            # Pad back to original size
+            d[key] = np.pad(d[key], ((0, 0), (0, 0), (pad_left, pad_right), (pad_top, pad_bottom)),
+                            mode='constant', constant_values=0)
+
+        return d
\ No newline at end of file
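The three transforms above form a crop/uncrop protocol around `d['to_pad']`. A round-trip sketch (illustrative, assuming 1 mm resolution and the key layout from the previous example):

```python
import numpy as np
from CorpusCallosum.transforms.segmentation_transforms import CropAroundACPCtrack, UncropAroundACPC

d = {
    "image": np.zeros((1, 9, 256, 256), dtype=np.float32),
    "AC_center": np.array([4.0, 128.0, 168.0]),
    "PC_center": np.array([4.0, 139.0, 143.0]),
    "res": 1.0,
}
cropped = CropAroundACPCtrack(keys=["image"], padding_mm=10)(d)  # also shifts AC/PC
restored = UncropAroundACPC(keys=["image"])(cropped)             # pads back via d["to_pad"]
print(restored["image"].shape)  # (1, 9, 256, 256) again
```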
diff --git a/CorpusCallosum/visualization/visualization.py b/CorpusCallosum/visualization/visualization.py
new file mode 100644
index 00000000..ae4f5fde
--- /dev/null
+++ b/CorpusCallosum/visualization/visualization.py
@@ -0,0 +1,276 @@
+from pathlib import Path
+from typing import List
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+def plot_standardized_space(ax_row, vol, ac_coords, pc_coords):
+    """Plot standardized space visualization across three views.
+
+    Args:
+        ax_row: Row of axes to plot on (should be length 3)
+        vol: Volume data to visualize
+        ac_coords: AC coordinates in standardized space
+        pc_coords: PC coordinates in standardized space
+    """
+    # Axial view
+    ax_row[0].scatter(ac_coords[2], ac_coords[1], color='red', marker='x')
+    ax_row[0].scatter(pc_coords[2], pc_coords[1], color='blue', marker='x')
+    ax_row[0].imshow(vol[vol.shape[0]//2], cmap='gray')
+
+    # Sagittal view
+    ax_row[1].scatter(ac_coords[2], ac_coords[0], color='red', marker='x')
+    ax_row[1].scatter(pc_coords[2], pc_coords[0], color='blue', marker='x')
+    ax_row[1].imshow(vol[:,vol.shape[1]//2], cmap='gray')
+
+    # Coronal view
+    ax_row[2].scatter(ac_coords[1], ac_coords[0], color='red', marker='x')
+    ax_row[2].scatter(pc_coords[1], pc_coords[0], color='blue', marker='x')
+    ax_row[2].imshow(vol[:,:,vol.shape[2]//2], cmap='gray')
+
+
+def visualize_coordinate_spaces(orig, upright, standardized,
+                              ac_coords_orig, pc_coords_orig,
+                              ac_coords_3d, pc_coords_3d,
+                              ac_coords_standardized, pc_coords_standardized,
+                              output_dir):
+    """
+    Visualize the AC and PC coordinates in different coordinate spaces for testing/debugging.
+
+    Args:
+        orig: Original image (nibabel image, original space)
+        upright: Volume mapped to fsaverage (upright) space
+        standardized: Volume in AC/PC standardized space
+        ac_coords_orig, pc_coords_orig: AC/PC coordinates in original space
+        ac_coords_3d, pc_coords_3d: AC/PC coordinates in fsaverage space
+        ac_coords_standardized, pc_coords_standardized: AC/PC coordinates in standardized space
+        output_dir: Directory to save visualization
+    """
+    fig, ax = plt.subplots(3, 4)
+    ax = ax.T
+
+    # Original space - using plot_standardized_space
+    plot_standardized_space(ax[0], orig.get_fdata(), ac_coords_orig, pc_coords_orig)
+    ax[0,0].set_title('Orig')
+
+    # Fsaverage space
+    plot_standardized_space(ax[1], upright, ac_coords_3d, pc_coords_3d)
+    ax[1,0].set_title('Fsaverage')
+
+    # Standardized space
+    plot_standardized_space(ax[2], standardized, ac_coords_standardized, pc_coords_standardized)
+    ax[2,0].set_title('Standardized')
+
+    # Format all subplots
+    for a in ax.flatten():
+        a.set_aspect('equal', adjustable='box')
+        a.axis('off')
+
+    plt.savefig(Path(output_dir) / "ac_pc_spaces.png", dpi=300, bbox_inches='tight')
+    plt.show()
+    plt.close()
+
+
+def plot_contours(transformed: np.ndarray,
+                  split_contours: List[np.ndarray],
+                  split_contours_hofer_frahm: List[np.ndarray],
+                  midline_equidistant: np.ndarray,
+                  levelpaths: List[np.ndarray],
+                  output_path: str,
+                  ac_coords: np.ndarray,
+                  pc_coords: np.ndarray,
+                  vox_size: float,
+                  title: str = None) -> None:
+    """Plots corpus callosum contours and segmentations.
+
+    Creates a figure with three subplots showing:
+    1. Midline-based subsegmentation
+    2. Hofer-Frahm segmentation scheme
+    3. Midline and levelpaths visualization
+
+    Args:
+        transformed: The transformed brain image array
+        split_contours: List of contour arrays for midline-based segmentation
+        split_contours_hofer_frahm: List of contour arrays for Hofer-Frahm segmentation
+        midline_equidistant: Array of midline points
+        levelpaths: List of levelpath arrays
+        output_path: Path where the output plot is saved
+        ac_coords: Anterior commissure coordinates
+        pc_coords: Posterior commissure coordinates
+        vox_size: Voxel size in mm used to scale the contour data
+        title: Optional title for the first subplot
+    """
+    # scale contour data by vox_size
+    split_contours = [split_contour * vox_size for split_contour in split_contours] if split_contours is not None else None
+    split_contours_hofer_frahm = [split_contour * vox_size for split_contour in split_contours_hofer_frahm] if split_contours_hofer_frahm is not None else None
+    midline_equidistant = midline_equidistant * vox_size
+    levelpaths = [levelpath * vox_size for levelpath in levelpaths]
+
+    NO_PLOTS = 1
+    if split_contours is not None:
+        NO_PLOTS += 1
+    if split_contours_hofer_frahm is not None:
+        NO_PLOTS += 1
+
+    fig, ax = plt.subplots(1, NO_PLOTS, sharex=True, sharey=True, figsize=(15, 10))
+
+    PLT_NUM = 0
+
+    if split_contours is not None:
+        ax[PLT_NUM].imshow(transformed[transformed.shape[0]//2], cmap='gray')
+        ax[PLT_NUM].set_title(title)
+        for i in range(len(split_contours)):
+            ax[PLT_NUM].fill(split_contours[i][0,:], -split_contours[i][1,:], color='steelblue', alpha=0.25)
+            ax[PLT_NUM].plot(split_contours[i][0,:], -split_contours[i][1,:], color='mediumblue', linestyle='dotted', linewidth=0.7)
+        ax[PLT_NUM].plot(split_contours[0][0,:], -split_contours[0][1,:], color='mediumblue', linewidth=0.7)
+        ax[PLT_NUM].scatter(ac_coords[1], ac_coords[0], color='red', marker='x')
+        ax[PLT_NUM].scatter(pc_coords[1], pc_coords[0], color='blue', marker='x')
+        PLT_NUM += 1
+
+    if split_contours_hofer_frahm is not None:
+        ax[PLT_NUM].imshow(transformed[transformed.shape[0]//2], cmap='gray')
+        ax[PLT_NUM].set_title('Hofer-Frahm Jaenecke')
+        for i in range(len(split_contours_hofer_frahm)):
+            ax[PLT_NUM].fill(split_contours_hofer_frahm[i][0,:], -split_contours_hofer_frahm[i][1,:], color='steelblue', alpha=0.25)
+            ax[PLT_NUM].plot([split_contours_hofer_frahm[i][0,0], split_contours_hofer_frahm[i][0,-1]], [-split_contours_hofer_frahm[i][1,0], -split_contours_hofer_frahm[i][1,-1]], color='mediumblue', linestyle='dotted', linewidth=0.7)
+        ax[PLT_NUM].plot(split_contours_hofer_frahm[0][0,:], -split_contours_hofer_frahm[0][1,:], color='mediumblue', linewidth=0.7)
+        ax[PLT_NUM].scatter(ac_coords[1], ac_coords[0], color='red', marker='x')
+        ax[PLT_NUM].scatter(pc_coords[1], pc_coords[0], color='blue', marker='x')
+        PLT_NUM += 1
+
+    reference_contour = split_contours[0] if split_contours is not None else split_contours_hofer_frahm[0]
+
+    ax[PLT_NUM].imshow(transformed[transformed.shape[0]//2], cmap='gray')
+    for i in range(len(levelpaths)):
+        ax[PLT_NUM].plot(levelpaths[i][:,0], -levelpaths[i][:,1], color='brown', linewidth=0.8)
+    ax[PLT_NUM].set_title('Midline & Levelpaths')
+    ax[PLT_NUM].plot(midline_equidistant[:,0], -midline_equidistant[:,1], color='red')
+    ax[PLT_NUM].plot(reference_contour[0,:], -reference_contour[1,:], color='red', linewidth=0.5)
+
+    for a in ax.flatten():
+        a.set_aspect('equal', adjustable='box')
+        a.axis('off')
+
+    # get bounding box of contours
+    padding = 30
+    ax[0].set_xlim(reference_contour[0,:].min()-padding,
reference_contour[0,:].max()+padding) + ax[0].set_ylim((-reference_contour[1,:]).max()+padding, (-reference_contour[1,:]).min()-padding) + + + plt.savefig(output_path, dpi=300, bbox_inches='tight') + #plt.show() + + + +def plot_midplane(grid_orig, orig): + """ + Creates a 3D visualization of grid points in original image space. + + Args: + grid_orig: Grid points in original space + orig: Original image for dimension reference + """ + # Create a figure showing grid points in original space + + # Create 3D plot + fig = plt.figure(figsize=(10, 10)) + ax = fig.add_subplot(111, projection='3d') + + # Plot every 10th point to avoid overcrowding + sample_idx = np.arange(0, grid_orig.shape[1], 40) + ax.scatter( + grid_orig[0,sample_idx], + grid_orig[1,sample_idx], + grid_orig[2,sample_idx], + c='r', + alpha=0.1, + marker='.' + ) + + # Set labels + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + ax.set_title('Grid Points in Original Image Space') + + # Set axis limits to image dimensions + ax.set_xlim(0, orig.shape[0]) + ax.set_ylim(0, orig.shape[1]) + ax.set_zlim(0, orig.shape[2]) + + # Save plot + plt.show() + # plt.savefig('grid_points.png') + # plt.close() + diff --git a/requirements.mac.txt b/requirements.mac.txt index 95af69a7..5f8775c6 100644 --- a/requirements.mac.txt +++ b/requirements.mac.txt @@ -16,4 +16,8 @@ torchio>=0.18.83 torchvision>=0.15.2 tqdm>=4.65 yacs>=0.1.8 - +monai>=1.4.0 +meshpy>=2025.1.1 +pyrr>=0.10.3 +whippersnappy>=1.3.1 +pip>=25.0 \ No newline at end of file From e06844df1f802a26cb5cc803acb1b2712f2c7b8c Mon Sep 17 00:00:00 2001 From: ClePol Date: Wed, 17 Sep 2025 12:03:33 +0200 Subject: [PATCH 03/68] updated requirements and removed mri_cc from recon-surf --- {recon_surf => CorpusCallosum}/paint_cc_into_pred.py | 0 env/fastsurfer.yml | 3 +++ pyproject.toml | 5 ++++- recon_surf/recon-surf.sh | 8 ++++---- tools/export_pip-r.sh | 0 5 files changed, 11 insertions(+), 5 deletions(-) rename {recon_surf => CorpusCallosum}/paint_cc_into_pred.py (100%) mode change 100644 => 100755 tools/export_pip-r.sh diff --git a/recon_surf/paint_cc_into_pred.py b/CorpusCallosum/paint_cc_into_pred.py similarity index 100% rename from recon_surf/paint_cc_into_pred.py rename to CorpusCallosum/paint_cc_into_pred.py diff --git a/env/fastsurfer.yml b/env/fastsurfer.yml index e3b991bf..2913c0c7 100644 --- a/env/fastsurfer.yml +++ b/env/fastsurfer.yml @@ -29,3 +29,6 @@ dependencies: - torch==2.6.0+cu126 - torchio==0.20.4 - torchvision==0.21.0+cu126 + - meshpy>=2025.1.1 + - pyrr>=0.10.3 + - whippersnappy>=1.3.1 diff --git a/pyproject.toml b/pyproject.toml index 727065f8..7d05072e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,13 +50,16 @@ dependencies = [ 'torchvision>=0.15.2', 'tqdm>=4.65', 'yacs>=0.1.8', + 'monai>=1.4.0', + 'meshpy>=2025.1.1', + 'pyrr>=0.10.3', + 'whippersnappy>=1.3.1', 'pip>=25.0', ] [project.optional-dependencies] doc = [ 'furo!=2023.8.17', - 'whippersnappy>=1.3.1', 'memory-profiler', 'myst-parser', 'numpydoc', diff --git a/recon_surf/recon-surf.sh b/recon_surf/recon-surf.sh index 64806471..6f3cdb5f 100755 --- a/recon_surf/recon-surf.sh +++ b/recon_surf/recon-surf.sh @@ -627,11 +627,11 @@ fi # create aseg.auto including corpus callosum segmentation and 46 sec, requires norm.mgz # Note: if original input segmentation already contains CC, this will exit with ERROR # in the future maybe check and skip this step (and next) -cmd="mri_cc -aseg $aseg_nocc -o aseg.auto.mgz -lta $mdir/transforms/cc_up.lta $subject" -RunIt "$cmd" "$LF" +#cmd="mri_cc -aseg $aseg_nocc 
-o aseg.auto.mgz -lta $mdir/transforms/cc_up.lta $subject" +#RunIt "$cmd" "$LF" # add CC into aparc.DKTatlas+aseg.deep (not sure if this is really needed) -cmd="$python ${binpath}paint_cc_into_pred.py -in_cc $mdir/aseg.auto.mgz -in_pred $asegdkt_segfile -out $mdir/aparc.DKTatlas+aseg.deep.withCC.mgz" -RunIt "$cmd" "$LF" +#cmd="$python ${binpath}paint_cc_into_pred.py -in_cc $mdir/aseg.auto.mgz -in_pred $asegdkt_segfile -out $mdir/aparc.DKTatlas+aseg.deep.withCC.mgz" +#RunIt "$cmd" "$LF" # ============================= FILLED ===================================================== diff --git a/tools/export_pip-r.sh b/tools/export_pip-r.sh old mode 100644 new mode 100755 From cdf7687f2a838b5a9b2cf7f1d9080c9daa0a25a1 Mon Sep 17 00:00:00 2001 From: ClePol Date: Fri, 19 Sep 2025 11:46:36 +0200 Subject: [PATCH 04/68] updated requirements, formatting, cleanup --- CorpusCallosum/cc_visualization.py | 124 ++-- CorpusCallosum/data/constants.py | 3 +- CorpusCallosum/data/fsaverage_cc_template.py | 26 +- .../data/generate_fsaverage_centroids.py | 16 +- CorpusCallosum/data/read_write.py | 4 +- CorpusCallosum/fastsurfer_cc.py | 525 ++++++++++---- .../localization/localization_inference.py | 199 +---- .../registration/mapping_helpers.py | 252 ++++--- .../segmentation/segmentation_inference.py | 209 +----- .../segmentation_postprocessing.py | 7 +- CorpusCallosum/shape/cc_endpoint_heuristic.py | 126 +--- CorpusCallosum/shape/cc_mesh.py | 681 ++++++++--------- CorpusCallosum/shape/cc_metrics.py | 128 +--- CorpusCallosum/shape/cc_postprocessing.py | 83 ++- CorpusCallosum/shape/cc_subsegment_contour.py | 686 ++++++------------ CorpusCallosum/shape/cc_thickness.py | 204 +++--- CorpusCallosum/shape/resample_poly.py | 65 -- .../transforms/localization_transforms.py | 3 +- .../transforms/segmentation_transforms.py | 27 +- CorpusCallosum/visualization/visualization.py | 266 +++---- 20 files changed, 1608 insertions(+), 2026 deletions(-) delete mode 100644 CorpusCallosum/shape/resample_poly.py diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index fbae0946..def10b62 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -1,50 +1,72 @@ import argparse from pathlib import Path import numpy as np -from CorpusCallosum.data.fsaverage_cc_template import load_fsaverage_cc_template - +from CorpusCallosum.data.fsaverage_cc_template import load_fsaverage_cc_template from CorpusCallosum.shape.cc_mesh import CC_Mesh + def options_parse() -> argparse.Namespace: - """Parse command line arguments for the visualization pipeline. 
- """ + """Parse command line arguments for the visualization pipeline.""" parser = argparse.ArgumentParser(description="Visualize corpus callosum from template files.") parser.add_argument("--contours", type=str, required=False, help="Path to contours.txt file", default=None) parser.add_argument("--thickness", type=str, required=True, help="Path to thickness_values.txt file") - parser.add_argument("--measurement_points", type=str, required=True, - help="Path to measurement points file containing the original vertex indices where thickness was measured") + parser.add_argument( + "--measurement_points", + type=str, + required=True, + help="Path to measurement points file containing the original vertex indices where thickness was measured", + ) parser.add_argument("--output_dir", type=str, required=True, help="Directory for output files") parser.add_argument("--resolution", type=float, default=1.0, help="Resolution in mm for the mesh") - parser.add_argument("--smooth_iterations", type=int, default=1, help="Number of smoothing iterations to apply to the mesh") - parser.add_argument("--colormap", type=str, default="red_to_yellow", - choices=["red_to_blue", "blue_to_red", "red_to_yellow", "yellow_to_red"], - help="Colormap to use for thickness visualization") - parser.add_argument("--color_range", type=float, nargs=2, default=None, - metavar=('MIN', 'MAX'), - help="Optional fixed range for the colorbar (min max)") + parser.add_argument( + "--smooth_iterations", type=int, default=1, help="Number of smoothing iterations to apply to the mesh" + ) + parser.add_argument( + "--colormap", + type=str, + default="red_to_yellow", + choices=["red_to_blue", "blue_to_red", "red_to_yellow", "yellow_to_red"], + help="Colormap to use for thickness visualization", + ) + parser.add_argument( + "--color_range", + type=float, + nargs=2, + default=None, + metavar=("MIN", "MAX"), + help="Optional fixed range for the colorbar (min max)", + ) parser.add_argument("--legend", type=str, default="Thickness (mm)", help="Legend for the colorbar") parser.add_argument("--twoD", action="store_true", help="Generate 2D visualization instead of 3D mesh") - + args = parser.parse_args() - + # Create output directory if it doesn't exist Path(args.output_dir).mkdir(parents=True, exist_ok=True) - + return args -def main(contours_path: str | Path | None, thickness_path: str | Path, measurement_points_path: str | Path, - output_dir: str | Path, resolution: float = 1.0, smooth_iterations: int = 1, - colormap: str = "red_to_yellow", color_range: tuple[float, float] | None = None, - legend: str | None = None, twoD: bool = False) -> None: +def main( + contours_path: str | Path | None, + thickness_path: str | Path, + measurement_points_path: str | Path, + output_dir: str | Path, + resolution: float = 1.0, + smooth_iterations: int = 1, + colormap: str = "red_to_yellow", + color_range: tuple[float, float] | None = None, + legend: str | None = None, + twoD: bool = False, +) -> None: """Main function to visualize corpus callosum from template files. - + This function: 1. Loads contours and thickness values from template files 2. Creates a CC_Mesh object 3. 
Generates and saves visualizations - + Args: contours_path: Path to contours.txt file thickness_path: Path to thickness_values.txt file @@ -65,7 +87,7 @@ def main(contours_path: str | Path | None, thickness_path: str | Path, measureme thickness_path = Path(thickness_path) measurement_points_path = Path(measurement_points_path) output_dir = Path(output_dir) - + # Load data and create mesh cc_mesh = CC_Mesh(num_slices=1) # Will be resized when loading data @@ -76,45 +98,53 @@ def main(contours_path: str | Path | None, thickness_path: str | Path, measureme cc_mesh.contours[0] = np.stack(cc_contour).T cc_mesh.start_end_idx[0] = [anterior_endpoint_idx, posterior_endpoint_idx] - cc_mesh.load_thickness_values(str(thickness_path), str(measurement_points_path)) cc_mesh.set_resolution(resolution) if twoD: - #cc_mesh.smooth_contour(contour_idx=0, window_size=5) - cc_mesh.plot_cc_contour_with_levelsets(contour_idx=0, levelpaths=None, title=None, save_path=str(output_dir / 'cc_thickness_2d.png'), colorbar=True) + # cc_mesh.smooth_contour(contour_idx=0, window_size=5) + cc_mesh.plot_cc_contour_with_levelsets( + contour_idx=0, levelpaths=None, title=None, save_path=str(output_dir / "cc_thickness_2d.png"), colorbar=True + ) else: cc_mesh.fill_thickness_values() # Create and process mesh cc_mesh.create_mesh(smooth=smooth_iterations, closed=False) - - # Generate visualizations - cc_mesh.plot_mesh(colormap=colormap, color_range=color_range, thickness_overlay=True, show_contours=False, show_mesh_edges=True, legend=legend) - cc_mesh.plot_mesh(str(output_dir / 'cc_mesh.html'), thickness_overlay=True) + cc_mesh.plot_mesh( + colormap=colormap, + color_range=color_range, + thickness_overlay=True, + show_contours=False, + show_mesh_edges=True, + legend=legend, + ) + cc_mesh.plot_mesh(str(output_dir / "cc_mesh.html"), thickness_overlay=True) + + cc_mesh.plot_cc_contour_with_levelsets( + contour_idx=len(cc_mesh.contours) // 2, save_path=str(output_dir / "midslice_2d.png") + ) - cc_mesh.plot_cc_contour_with_levelsets(contour_idx=len(cc_mesh.contours)//2, save_path=str(output_dir / 'midslice_2d.png')) - cc_mesh.to_fs_coordinates() - cc_mesh.write_vtk(str(output_dir / 'cc_mesh.vtk')) - cc_mesh.write_fssurf(str(output_dir / 'cc_mesh.fssurf')) - cc_mesh.write_overlay(str(output_dir / 'cc_mesh_overlay.curv')) - cc_mesh.snap_cc_picture(str(output_dir / 'cc_mesh_snap.png')) + cc_mesh.write_vtk(str(output_dir / "cc_mesh.vtk")) + cc_mesh.write_fssurf(str(output_dir / "cc_mesh.fssurf")) + cc_mesh.write_overlay(str(output_dir / "cc_mesh_overlay.curv")) + cc_mesh.snap_cc_picture(str(output_dir / "cc_mesh_snap.png")) if __name__ == "__main__": options = options_parse() main_args = { - 'contours_path': options.contours, - 'thickness_path': options.thickness, - 'measurement_points_path': options.measurement_points, - 'output_dir': options.output_dir, - 'resolution': options.resolution, - 'smooth_iterations': options.smooth_iterations, - 'colormap': options.colormap, - 'color_range': options.color_range, - 'legend': options.legend, - 'twoD': options.twoD + "contours_path": options.contours, + "thickness_path": options.thickness, + "measurement_points_path": options.measurement_points, + "output_dir": options.output_dir, + "resolution": options.resolution, + "smooth_iterations": options.smooth_iterations, + "colormap": options.colormap, + "color_range": options.color_range, + "legend": options.legend, + "twoD": options.twoD, } - main(**main_args) \ No newline at end of file + main(**main_args) diff --git 
a/CorpusCallosum/data/constants.py b/CorpusCallosum/data/constants.py index 9c70642e..60125bd2 100644 --- a/CorpusCallosum/data/constants.py +++ b/CorpusCallosum/data/constants.py @@ -18,5 +18,6 @@ "orient_volume_lta": "transforms/orient_volume.lta", "orig_space_segmentation": "mri/segmentation_orig_space.mgz", "debug_image": "stats/cc_postprocessing.png", - #"qc_view": "stats/qc_view.png" + "qc_view": "qc-snapshots/corpus_callosum.png", + "qc_view3d": "qc-snapshots/corpus_callosum_thickness.png" } \ No newline at end of file diff --git a/CorpusCallosum/data/fsaverage_cc_template.py b/CorpusCallosum/data/fsaverage_cc_template.py index 0f8329db..00736783 100644 --- a/CorpusCallosum/data/fsaverage_cc_template.py +++ b/CorpusCallosum/data/fsaverage_cc_template.py @@ -1,10 +1,11 @@ -import nibabel as nib -import matplotlib.pyplot as plt -from shape.cc_postprocessing import process_slice -from pathlib import Path import os -from scipy import ndimage +from pathlib import Path + +import nibabel as nib import numpy as np +from scipy import ndimage +from shape.cc_postprocessing import process_slice + def smooth_contour(contour, window_size=5): """ @@ -50,7 +51,9 @@ def load_fsaverage_cc_template(): freesurfer_home = Path(os.environ['FREESURFER_HOME']) if not freesurfer_home.exists(): - raise EnvironmentError(f"FREESURFER_HOME environment variable is not set correctly or does not exist: {freesurfer_home}, either provide your own template or set the FREESURFER_HOME environment variable") + raise OSError(f"FREESURFER_HOME environment variable is not set correctly or does not exist: " + f"{freesurfer_home}, either provide your own template or set the " + f"FREESURFER_HOME environment variable") fsaverage_seg_path = freesurfer_home / 'subjects' / 'fsaverage' / 'mri' / 'aparc+aseg.mgz' fsaverage_seg = nib.load(fsaverage_seg_path) @@ -84,7 +87,16 @@ def load_fsaverage_cc_template(): cc_mask = cc_mask_smoothed.astype(int) cc_mask[cc_mask > 0] = 192 - output_dict, contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx = process_slice(cc_mask[None], 0, AC, PC, fsaverage_seg.affine, 100, [1/6, 1/2, 2/3, 3/4], "shape", 1.0, verbose=False) + (_, contour_with_thickness, anterior_endpoint_idx, + posterior_endpoint_idx) = process_slice(segmentation=cc_mask[None], + slice_idx=0, + ac_coords=AC, + pc_coords=PC, + affine=fsaverage_seg.affine, + num_thickness_points=100, + subdivisions=[1/6, 1/2, 2/3, 3/4], + subdivision_method="shape", + contour_smoothing=1.0) outside_contour = contour_with_thickness[0].T diff --git a/CorpusCallosum/data/generate_fsaverage_centroids.py b/CorpusCallosum/data/generate_fsaverage_centroids.py index 46ac01bd..f97a1777 100644 --- a/CorpusCallosum/data/generate_fsaverage_centroids.py +++ b/CorpusCallosum/data/generate_fsaverage_centroids.py @@ -22,18 +22,18 @@ def main(): try: fs_home = Path(os.environ['FREESURFER_HOME']) if not fs_home.exists(): - raise EnvironmentError(f"FREESURFER_HOME environment variable is not set correctly or does not exist: {fs_home}") + raise OSError(f"FREESURFER_HOME environment variable is not set correctly or does not exist: {fs_home}") fsaverage_path = fs_home / 'subjects' / 'fsaverage' if not fsaverage_path.exists(): - raise EnvironmentError(f"fsaverage path does not exist: {fsaverage_path}") + raise OSError(f"fsaverage path does not exist: {fsaverage_path}") fsaverage_aseg_path = fsaverage_path / 'mri' / 'aseg.mgz' if not fsaverage_aseg_path.exists(): raise FileNotFoundError(f"fsaverage aseg file does not exist: {fsaverage_aseg_path}") - except 
KeyError: - raise EnvironmentError("FREESURFER_HOME environment variable is not set") + except KeyError as err: + raise OSError("FREESURFER_HOME environment variable is not set") from err print(f"Loading fsaverage segmentation from: {fsaverage_aseg_path}") @@ -102,20 +102,20 @@ def main(): # Print some statistics label_ids = list(centroids_dst.keys()) print(f"Label IDs range: {min(label_ids)} to {max(label_ids)}") - print(f"Sample centroids:") + print("Sample centroids:") for label_id in sorted(label_ids)[:5]: centroid = centroids_dst[label_id] print(f" Label {label_id}: [{centroid[0]:.2f}, {centroid[1]:.2f}, {centroid[2]:.2f}]") - print(f"Fsaverage affine matrix:") + print("Fsaverage affine matrix:") print(fsaverage_affine) - print(f"Fsaverage header fields:") + print("Fsaverage header fields:") print(f" dims: {dims}") print(f" delta: {delta}") print(f" Mdc shape: {Mdc.shape}") print(f" Pxyz_c: {Pxyz_c}") - print(f"Combined data structure created successfully") + print("Combined data structure created successfully") if __name__ == "__main__": diff --git a/CorpusCallosum/data/read_write.py b/CorpusCallosum/data/read_write.py index 3db13150..16947de7 100644 --- a/CorpusCallosum/data/read_write.py +++ b/CorpusCallosum/data/read_write.py @@ -158,7 +158,7 @@ def load_fsaverage_centroids(centroids_path): if not centroids_path.exists(): raise FileNotFoundError(f"Fsaverage centroids file not found: {centroids_path}") - with open(centroids_path, 'r') as f: + with open(centroids_path) as f: centroids_data = json.load(f) # Convert string keys back to integers and lists back to numpy arrays @@ -230,7 +230,7 @@ def load_fsaverage_data(data_path): if not data_path.exists(): raise FileNotFoundError(f"Fsaverage data file not found: {data_path}") - with open(data_path, 'r') as f: + with open(data_path) as f: data = json.load(f) # Verify required fields diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index 7f2f9f6a..14171727 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -1,6 +1,7 @@ import argparse import json import warnings + warnings.filterwarnings("ignore", message="TypedStorage is deprecated") from pathlib import Path @@ -8,77 +9,184 @@ import nibabel as nib import numpy as np import torch - -from localization import localization_inference -from segmentation import segmentation_inference, segmentation_postprocessing +from CorpusCallosum.localization import localization_inference from recon_surf import lta -from CorpusCallosum.registration.mapping_helpers import interpolate_midplane, get_mapping_to_standard_space, map_softlabels_to_orig, apply_transform_to_pt, apply_transform_and_map_volume -from CorpusCallosum.shape.cc_postprocessing import process_slices, create_visualization - -from FastSurferCNN.data_loader.conform import is_conform from recon_surf.align_points import find_rigid -from CorpusCallosum.data.read_write import save_nifti_background, get_centroids_from_nib, convert_numpy_to_json_serializable, run_in_background, load_fsaverage_centroids, load_fsaverage_data - -from CorpusCallosum.data.constants import * - +from CorpusCallosum.segmentation import segmentation_inference, segmentation_postprocessing +from FastSurferCNN.data_loader.conform import is_conform +from CorpusCallosum.data.constants import ( + FSAVERAGE_CENTROIDS_PATH, + FSAVERAGE_DATA_PATH, + WEIGHTS_PATH, + STANDARD_OUTPUT_PATHS, + FSAVERAGE_MIDDLE, + CC_LABEL, +) +from CorpusCallosum.data.read_write import ( + convert_numpy_to_json_serializable, + 
get_centroids_from_nib, + load_fsaverage_centroids, + load_fsaverage_data, + run_in_background, + save_nifti_background, +) +from CorpusCallosum.registration.mapping_helpers import ( + apply_transform_and_map_volume, + apply_transform_to_pt, + get_mapping_to_standard_space, + interpolate_midplane, + map_softlabels_to_orig, +) +from CorpusCallosum.shape.cc_postprocessing import create_visualization, process_slices def options_parse() -> argparse.Namespace: - """Parse command line arguments for the pipeline. - """ + """Parse command line arguments for the pipeline.""" parser = argparse.ArgumentParser() - parser.add_argument("--in_mri", type=str, required=False, help="Input MRI file path. If not provided, defaults to subject_dir/mri/orig.mgz") - parser.add_argument("--aseg", type=str, required=False, help="Input segmentation file path. If not provided, defaults to subject_dir/mri/aparc.DKTatlas+aseg.deep.mgz") - parser.add_argument("--subject_dir", type=str, required=False, help="Subject directory containing standard FreeSurfer structure. Required if --in_mri and --aseg are not both provided.", default=None) + parser.add_argument( + "--in_mri", + type=str, + required=False, + help="Input MRI file path. If not provided, defaults to subject_dir/mri/orig.mgz", + ) + parser.add_argument( + "--cpu", + action="store_true", + help="Force CPU usage even when CUDA is available", + ) + parser.add_argument( + "--aseg", + type=str, + required=False, + help="Input segmentation file path. If not provided, defaults to subject_dir/mri/aparc.DKTatlas+aseg.deep.mgz", + ) + parser.add_argument( + "--subject_dir", + type=str, + required=False, + help="Subject directory containing standard FreeSurfer structure. " + "Required if --in_mri and --aseg are not both provided.", + default=None, + ) parser.add_argument("--debug_output_dir", type=str, required=False, default=None) parser.add_argument("--verbose", action="store_true", help="Enable verbose output and debug plots") - + # CC shape arguments - parser.add_argument("--num_thickness_points", type=int, default=100, help="Number of points for thickness estimation.") - parser.add_argument("--subdivisions", type=float, nargs='+', default=[1/6, 1/2, 2/3, 3/4], help="List of subdivision fractions for the corpus callosum subsegmentation.") - parser.add_argument("--subdivision_method", type=str, default="shape", help="Method for contour subdivision. \ + parser.add_argument( + "--num_thickness_points", type=int, default=100, help="Number of points for thickness estimation." + ) + parser.add_argument( + "--subdivisions", + type=float, + nargs="+", + default=[1 / 6, 1 / 2, 2 / 3, 3 / 4], + help="List of subdivision fractions for the corpus callosum subsegmentation.", + ) + parser.add_argument( + "--subdivision_method", + type=str, + default="shape", + help="Method for contour subdivision. \ Options: shape (Intercallosal subdivision perpendicular to intercallosal line), vertical \ (orthogonal to the most anterior and posterior points in the AC/PC standardized CC contour), \ angular (subdivision based on equally spaced angles, as proposed by Hampel and colleagues), \ - eigenvector (primary direction, same as FreeSurfers mri_cc)", choices=["shape", "vertical", "angular", "eigenvector"]) - parser.add_argument("--contour_smoothing", type=float, default=1.0, help="Gaussian sigma for smoothing during contour detection. 
Default is 1.0, higher values mean a smoother outline, at the cost of precision.") - parser.add_argument("--slice_selection", type=str, default="middle", help="Which slices to process. Options: 'middle' (default), 'all', or a specific slice number.") - + eigenvector (primary direction, same as FreeSurfers mri_cc)", + choices=["shape", "vertical", "angular", "eigenvector"], + ) + parser.add_argument( + "--contour_smoothing", + type=float, + default=1.0, + help="Gaussian sigma for smoothing during contour detection. Default is 1.0, higher values mean a smoother" + "outline, at the cost of precision.", + ) + parser.add_argument( + "--slice_selection", + type=str, + default="middle", + help="Which slices to process. Options: 'middle' (default), 'all', or a specific slice number.", + ) + # Output path arguments - parser.add_argument("--upright_volume_path", type=str, help="Path for upright volume output (default: subject_dir/stats/upright_volume.mgz)", default=None) - parser.add_argument("--segmentation_path", type=str, help="Path for segmentation output (default: subject_dir/stats/cc_segmentation.mgz)", default=None) - parser.add_argument("--postproc_results_path", type=str, help="Path for postprocessing results (default: subject_dir/stats/cc_postproc_results.json)", default=None) - parser.add_argument("--cc_markers_path", type=str, help="Path for CC markers output (default: subject_dir/stats/cc_markers.json)", default=None) - parser.add_argument("--upright_lta_path", type=str, help="Path for upright LTA transform (default: subject_dir/transforms/upright.lta)", default=None) - parser.add_argument("--orient_volume_lta_path", type=str, help="Path for orientation volume LTA transform (default: subject_dir/transforms/orient_volume.lta)", default=None) - parser.add_argument("--orig_space_segmentation_path", type=str, help="Path for segmentation in original space (default: subject_dir/mri/segmentation_orig_space.mgz)", default=None) - parser.add_argument("--debug_image_path", type=str, help="Path for debug visualization image (default: subject_dir/stats/cc_postprocessing.png)", default=None) - + parser.add_argument( + "--upright_volume_path", + type=str, + help="Path for upright volume output (default: subject_dir/stats/upright_volume.mgz)", + default=None, + ) + parser.add_argument( + "--segmentation_path", + type=str, + help="Path for segmentation output (default: subject_dir/stats/cc_segmentation.mgz)", + default=None, + ) + parser.add_argument( + "--postproc_results_path", + type=str, + help="Path for postprocessing results (default: subject_dir/stats/cc_postproc_results.json)", + default=None, + ) + parser.add_argument( + "--cc_markers_path", + type=str, + help="Path for CC markers output (default: subject_dir/stats/cc_markers.json)", + default=None, + ) + parser.add_argument( + "--upright_lta_path", + type=str, + help="Path for upright LTA transform (default: subject_dir/transforms/upright.lta)", + default=None, + ) + parser.add_argument( + "--orient_volume_lta_path", + type=str, + help="Path for orientation volume LTA transform (default: subject_dir/transforms/orient_volume.lta)", + default=None, + ) + parser.add_argument( + "--orig_space_segmentation_path", + type=str, + help="Path for segmentation in original space (default: subject_dir/mri/segmentation_orig_space.mgz)", + default=None, + ) + parser.add_argument( + "--debug_image_path", + type=str, + help="Path for debug visualization image (default: subject_dir/stats/cc_postprocessing.png)", + default=None, + ) + # Template saving argument 
- parser.add_argument("--save_template", type=str, help="Directory path where to save contours.txt and thickness_values.txt files", default=None) - + parser.add_argument( + "--save_template", + type=str, + help="Directory path where to save contours.txt and thickness_values.txt files", + default=None, + ) + args = parser.parse_args() - + # Validation logic: either subject_dir OR both in_mri and aseg must be provided if not args.subject_dir and (not args.in_mri or not args.aseg): parser.error("You must specify either --subject_dir OR both --in_mri and --aseg arguments.") - + # If subject_dir is provided, set default paths for missing arguments if args.subject_dir: subject_dir_path = Path(args.subject_dir) - + # Create standard FreeSurfer subdirectories (subject_dir_path / "mri").mkdir(parents=True, exist_ok=True) (subject_dir_path / "stats").mkdir(parents=True, exist_ok=True) (subject_dir_path / "transforms").mkdir(parents=True, exist_ok=True) - + if not args.in_mri: args.in_mri = str(subject_dir_path / "mri" / "orig.mgz") - + if not args.aseg: args.aseg = str(subject_dir_path / "mri" / "aparc.DKTatlas+aseg.deep.mgz") - + # Set default output paths if not provided for key, value in STANDARD_OUTPUT_PATHS.items(): if not getattr(args, f"{key}_path"): @@ -87,25 +195,31 @@ def options_parse() -> argparse.Namespace: # Set output_dir to subject_dir args.output_dir = str(subject_dir_path) - # Create parent directories for all output paths - for path in [args.upright_volume_path, args.segmentation_path, args.postproc_results_path, args.cc_markers_path, args.upright_lta_path, args.orient_volume_lta_path]: + for path in [ + args.upright_volume_path, + args.segmentation_path, + args.postproc_results_path, + args.cc_markers_path, + args.upright_lta_path, + args.orient_volume_lta_path, + ]: if path is not None: Path(path).parent.mkdir(parents=True, exist_ok=True) - + return args def centroid_registration(aseg_nib, verbose=False): """Perform centroid-based registration between subject and fsaverage space. - + Computes a rigid transformation between the subject's segmentation and fsaverage space by aligning centroids of corresponding anatomical structures. 
- + Args: aseg_nib (nib.Nifti1Image): Subject's segmentation image verbose (bool): Whether to print progress information - + Returns: tuple: Contains: - orig_fsaverage_vox2vox: Transformation matrix from original to fsaverage voxel space @@ -119,7 +233,7 @@ def centroid_registration(aseg_nib, verbose=False): # Load pre-computed fsaverage centroids and data from static files centroids_dst = load_fsaverage_centroids(FSAVERAGE_CENTROIDS_PATH) fsaverage_affine, fsaverage_header = load_fsaverage_data(FSAVERAGE_DATA_PATH) - + centroids_mov, ids_not_found = get_centroids_from_nib(aseg_nib, label_ids=list(centroids_dst.keys())) # delete not found labels from centroids_mov @@ -138,7 +252,9 @@ def centroid_registration(aseg_nib, verbose=False): resolution_trans[1, 1] = resolution_orig resolution_trans[2, 2] = resolution_orig - orig_fsaverage_vox2vox = np.linalg.inv(resolution_trans @ fsaverage_affine) @ orig_fsaverage_ras2ras @ aseg_nib.affine + orig_fsaverage_vox2vox = ( + np.linalg.inv(resolution_trans @ fsaverage_affine) @ orig_fsaverage_ras2ras @ aseg_nib.affine + ) fsaverage_hires_affine = resolution_trans @ fsaverage_affine return orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, fsaverage_hires_affine, fsaverage_header @@ -146,10 +262,10 @@ def centroid_registration(aseg_nib, verbose=False): def localize_ac_pc(midslices, aseg_nib, orig_fsaverage_vox2vox, model_localization, slices_to_analyze, verbose=False): """Localize anterior and posterior commissure points in the brain. - + Uses a trained model to detect AC and PC points in mid-sagittal slices, using the third ventricle as an anatomical reference. - + Args: midslices (np.ndarray): Array of mid-sagittal slices aseg_nib (nib.Nifti1Image): Subject's segmentation image @@ -158,7 +274,7 @@ def localize_ac_pc(midslices, aseg_nib, orig_fsaverage_vox2vox, model_localizati model_localization: Trained model for AC-PC detection slices_to_analyze (int): Number of slices to process verbose (bool): Whether to print progress information - + Returns: tuple: Contains: - ac_coords (np.ndarray): Coordinates of the anterior commissure @@ -174,18 +290,22 @@ def localize_ac_pc(midslices, aseg_nib, orig_fsaverage_vox2vox, model_localizati # get 5 mm of slices output with 3 slices per inference midslices_middle = midslices.shape[0] // 2 - middle_slices_localization = midslices[midslices_middle-slices_to_analyze//2-1:midslices_middle+slices_to_analyze//2+2] - ac_coords, pc_coords = localization_inference.run_inference_on_slice(model_localization, middle_slices_localization, third_ventricle_center_vox[1:]) + middle_slices_localization = midslices[ + midslices_middle - slices_to_analyze // 2 - 1 : midslices_middle + slices_to_analyze // 2 + 2 + ] + ac_coords, pc_coords = localization_inference.run_inference_on_slice( + model_localization, middle_slices_localization, third_ventricle_center_vox[1:] + ) return ac_coords, pc_coords def segment_cc(midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, slices_to_analyze): """Segment the corpus callosum using a trained model. - + Performs corpus callosum segmentation on mid-sagittal slices using a trained model, with AC-PC points as anatomical references. Includes post-processing to clean the segmentation. 
- + Args: midslices (np.ndarray): Array of mid-sagittal slices ac_coords (np.ndarray): Anterior commissure coordinates @@ -196,7 +316,7 @@ def segment_cc(midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, sl model_segmentation: Trained model for CC segmentation slices_to_analyze (int): Number of slices to process verbose (bool): Whether to print progress information - + Returns: tuple: Contains: - segmentation (np.ndarray): Binary segmentation of the corpus callosum @@ -204,17 +324,27 @@ def segment_cc(midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, sl """ # get 5 mm of slices output with 9 slices per inference midslices_middle = midslices.shape[0] // 2 - middle_slices_segmentation = midslices[midslices_middle-slices_to_analyze//2-4:midslices_middle+slices_to_analyze//2+5] - segmentation, inputs, outputs_avg, outputs_soft = segmentation_inference.run_inference_on_slice(model_segmentation, - middle_slices_segmentation, - AC_center=ac_coords, PC_center=pc_coords, - voxel_size=aseg_nib.header.get_zooms()[0]) + middle_slices_segmentation = midslices[ + midslices_middle - slices_to_analyze // 2 - 4 : midslices_middle + slices_to_analyze // 2 + 5 + ] + segmentation, inputs, outputs_avg, outputs_soft = segmentation_inference.run_inference_on_slice( + model_segmentation, + middle_slices_segmentation, + AC_center=ac_coords, + PC_center=pc_coords, + voxel_size=aseg_nib.header.get_zooms()[0], + ) pre_clean_segmentation = segmentation.copy() segmentation, cc_volume_mask = segmentation_postprocessing.clean_cc_segmentation(segmentation) # print a warning if the cc_volume_mask touches the edge of the segmentation - if np.any(cc_volume_mask[:,0,:]) or np.any(cc_volume_mask[:,-1,:]) or np.any(cc_volume_mask[:,:,0]) or np.any(cc_volume_mask[:,:,-1]): + if ( + np.any(cc_volume_mask[:, 0, :]) + or np.any(cc_volume_mask[:, -1, :]) + or np.any(cc_volume_mask[:, :, 0]) + or np.any(cc_volume_mask[:, :, -1]) + ): print("Warning: CC volume mask touches the edge of the segmentation field-of-view, CC might be truncated") # get voxels that were removed during cleaning @@ -224,17 +354,30 @@ def segment_cc(midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, sl return segmentation, outputs_soft -def main(in_mri_path: str | Path, aseg_path: str | Path, output_dir: str | Path, slice_selection: str = "middle", - debug_output_dir: str | Path = None, verbose: bool = False, num_thickness_points: int = 100, - subdivisions: list[float] | None = None, subdivision_method: str = "shape", - contour_smoothing: float = 1.0, - upright_volume_path: str | Path = None, segmentation_path: str | Path = None, - postproc_results_path: str | Path = None, cc_markers_path: str | Path = None, - upright_lta_path: str | Path = None, orient_volume_lta_path: str | Path = None, - orig_space_segmentation_path: str | Path = None, debug_image_path: str | Path = None, - save_template: str | Path | None = None) -> None: +def main( + in_mri_path: str | Path, + aseg_path: str | Path, + output_dir: str | Path, + slice_selection: str = "middle", + debug_output_dir: str | Path = None, + verbose: bool = False, + num_thickness_points: int = 100, + subdivisions: list[float] | None = None, + subdivision_method: str = "shape", + contour_smoothing: float = 1.0, + upright_volume_path: str | Path = None, + segmentation_path: str | Path = None, + postproc_results_path: str | Path = None, + cc_markers_path: str | Path = None, + upright_lta_path: str | Path = None, + orient_volume_lta_path: str | Path = None, + 
orig_space_segmentation_path: str | Path = None, + debug_image_path: str | Path = None, + save_template: str | Path | None = None, + cpu: bool = False, +) -> None: """Main pipeline function for corpus callosum analysis. - + This function performs the following steps: 1. Initializes environment and loads models 2. Registers input image to fsaverage space @@ -242,7 +385,7 @@ def main(in_mri_path: str | Path, aseg_path: str | Path, output_dir: str | Path, 4. Segments the corpus callosum 5. Performs enhanced post-processing analysis 6. Saves results and visualizations - + Args: in_mri_path: Path to input MRI file aseg_path: Path to input segmentation file @@ -260,10 +403,12 @@ def main(in_mri_path: str | Path, aseg_path: str | Path, output_dir: str | Path, cc_markers_path: Path for CC markers output (default: output_dir/cc_markers.json) upright_lta_path: Path for upright LTA transform (default: output_dir/upright.lta) orient_volume_lta_path: Path for orientation volume LTA transform (default: output_dir/orient_volume.lta) - orig_space_segmentation_path: Path for segmentation in original space (default: output_dir/mri/segmentation_orig_space.mgz) + orig_space_segmentation_path: Path for segmentation in original space + (default: output_dir/mri/segmentation_orig_space.mgz) debug_image_path: Path for debug visualization image (default: output_dir/stats/cc_postprocessing.png) save_template: Directory path where to save contours.txt and thickness_values.txt files - + cpu: Force CPU usage even when CUDA is available + The function saves multiple outputs to specified paths or default locations in output_dir: - cc_markers.json: Contains detected landmarks and measurements - midplane_slices.mgz: Extracted midplane slices @@ -272,28 +417,27 @@ def main(in_mri_path: str | Path, aseg_path: str | Path, output_dir: str | Path, - cc_postproc_results.json: Enhanced postprocessing results - Various visualization plots and transformation matrices """ - + if subdivisions is None: - subdivisions = [1/6, 1/2, 2/3, 3/4] - + subdivisions = [1 / 6, 1 / 2, 2 / 3, 3 / 4] + # Convert all paths to Path objects in_mri_path = Path(in_mri_path) aseg_path = Path(aseg_path) output_dir = Path(output_dir) debug_output_dir = Path(debug_output_dir) if debug_output_dir else None save_template = Path(save_template) if save_template else None - + # Validate subdivision fractions for i in subdivisions: if i < 0 or i > 1: - print('Error: Subdivision fractions must be between 0 and 1, but got: ', i) + print("Error: Subdivision fractions must be between 0 and 1, but got: ", i) exit(1) #### setup variables IO_processes = [] - + orig = nib.load(in_mri_path) - # 5 mm around the midplane slices_to_analyze = int(np.ceil(5 / orig.header.get_zooms()[0])) @@ -301,45 +445,75 @@ def main(in_mri_path: str | Path, aseg_path: str | Path, output_dir: str | Path, slices_to_analyze += 1 if verbose: - print(f"Segmenting {slices_to_analyze} slices (5 mm width at {orig.header.get_zooms()[0]} mm resolution, center around the mid-sagittal plane)") - + print( + f"Segmenting {slices_to_analyze} slices (5 mm width at {orig.header.get_zooms()[0]} mm resolution, " + "center around the mid-sagittal plane)" + ) if not is_conform(orig, conform_vox_size=orig.header.get_zooms()[0]): print("Error: MRI is not conformed, please run conform.py or mri_convert to conform the image.") exit(1) # load models - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model_localization = localization_inference.load_model(str(Path(WEIGHTS_PATH) / 
"localization_weights_acpc.pth"), device=device) - model_segmentation = segmentation_inference.load_model(str(Path(WEIGHTS_PATH) / "segmentation_weights_cc_fn.pth"), device=device) - + device = torch.device("cuda" if torch.cuda.is_available() and not cpu else "cpu") + model_localization = localization_inference.load_model( + str(Path(WEIGHTS_PATH) / "localization_weights_acpc.pth"), device=device + ) + model_segmentation = segmentation_inference.load_model( + str(Path(WEIGHTS_PATH) / "segmentation_weights_cc_fn.pth"), device=device + ) aseg_nib = nib.load(aseg_path) - orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, fsaverage_hires_affine, fsaverage_header = centroid_registration(aseg_nib, verbose) + orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, fsaverage_hires_affine, fsaverage_header = centroid_registration( + aseg_nib, verbose + ) if verbose: print("Interpolating midplane") - + # this is a fast interpolation to not block the main thread midslices = interpolate_midplane(orig, orig_fsaverage_vox2vox, slices_to_analyze) - # start saving upright volume - IO_processes.append(run_in_background(apply_transform_and_map_volume, False, - orig.get_fdata(), orig_fsaverage_vox2vox, fsaverage_hires_affine, None, upright_volume_path, output_size=np.array([256,256,256]))) - + IO_processes.append( + run_in_background( + apply_transform_and_map_volume, + False, + orig.get_fdata(), + orig_fsaverage_vox2vox, + fsaverage_hires_affine, + None, + upright_volume_path, + output_size=np.array([256, 256, 256]), + ) + ) + #### do localization and segmentation inference - ac_coords, pc_coords = localize_ac_pc(midslices, aseg_nib, orig_fsaverage_vox2vox, model_localization, slices_to_analyze, verbose) - segmentation, outputs_soft = segment_cc(midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, slices_to_analyze) - + ac_coords, pc_coords = localize_ac_pc( + midslices, aseg_nib, orig_fsaverage_vox2vox, model_localization, slices_to_analyze, verbose + ) + segmentation, outputs_soft = segment_cc( + midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, slices_to_analyze + ) + # map soft labels to original space (in parallel because this takes a while) - IO_processes.append(run_in_background(map_softlabels_to_orig, False, - outputs_soft, orig_fsaverage_vox2vox, orig, slices_to_analyze, orig_space_segmentation_path, fsaverage_middle=FSAVERAGE_MIDDLE)) + IO_processes.append( + run_in_background( + map_softlabels_to_orig, + False, + outputs_soft, + orig_fsaverage_vox2vox, + orig, + slices_to_analyze, + orig_space_segmentation_path, + fsaverage_middle=FSAVERAGE_MIDDLE, + ) + ) # Create a temporary segmentation image with proper affine for enhanced postprocessing temp_seg_affine = fsaverage_hires_affine @ np.linalg.inv(np.eye(4)) - + # Process slices based on selection mode slice_results, slice_io_processes = process_slices( segmentation=segmentation, @@ -356,34 +530,39 @@ def main(in_mri_path: str | Path, aseg_path: str | Path, output_dir: str | Path, debug_image_path=debug_image_path, vox_size=orig.header.get_zooms()[0], verbose=verbose, - save_template=save_template + save_template=save_template, ) IO_processes.extend(slice_io_processes) - + # Get middle slice result for backward compatibility - middle_slice_result = slice_results[len(slice_results)//2] - + middle_slice_result = slice_results[len(slice_results) // 2] + # Create enhanced output dictionary with all slice results per_slice_output_dict = { - 'slices': [convert_numpy_to_json_serializable({ - 'slice_index': result['slice_index'], - 
'cc_index': result['cc_index'], - 'circularity': result['circularity'], - 'areas': result['areas'], - 'midline_length': result['midline_length'], - 'thickness': result['thickness'], - 'curvature': result['curvature'], - 'thickness_profile': result['thickness_profile'], - 'total_area': result['total_area'], - 'total_perimeter': result['total_perimeter'] - }) for result in slice_results], - 'slices_in_segmentation': segmentation.shape[0], - 'voxel_size': [float(x) for x in orig.header.get_zooms()], - 'subdivision_method': subdivision_method, - 'num_thickness_points': num_thickness_points, - 'subdivisions': subdivisions, - 'contour_smoothing': contour_smoothing, - 'slice_selection': slice_selection + "slices": [ + convert_numpy_to_json_serializable( + { + "slice_index": result["slice_index"], + "cc_index": result["cc_index"], + "circularity": result["circularity"], + "areas": result["areas"], + "midline_length": result["midline_length"], + "thickness": result["thickness"], + "curvature": result["curvature"], + "thickness_profile": result["thickness_profile"], + "total_area": result["total_area"], + "total_perimeter": result["total_perimeter"], + } + ) + for result in slice_results + ], + "slices_in_segmentation": segmentation.shape[0], + "voxel_size": [float(x) for x in orig.header.get_zooms()], + "subdivision_method": subdivision_method, + "num_thickness_points": num_thickness_points, + "subdivisions": subdivisions, + "contour_smoothing": contour_smoothing, + "slice_selection": slice_selection, } # Save slice-wise postprocessing results to JSON @@ -392,53 +571,73 @@ def main(in_mri_path: str | Path, aseg_path: str | Path, output_dir: str | Path, if verbose: print(f"Multiple slice post-processing results saved to {postproc_results_path}") - + ########## Save outputs ########## - cc_volume = segmentation_postprocessing.get_cc_volume(desired_width_mm=5, cc_mask=segmentation == CC_LABEL, voxel_size=orig.header.get_zooms()) + cc_volume = segmentation_postprocessing.get_cc_volume( + desired_width_mm=5, cc_mask=segmentation == CC_LABEL, voxel_size=orig.header.get_zooms() + ) # Create backward compatible output_dict for existing pipeline using middle slice output_dict = { - 'areas': middle_slice_result['areas'], - 'areas_hofer_frahm': middle_slice_result['areas'] if middle_slice_result['split_contours_hofer_frahm'] is not None else middle_slice_result['areas'], - 'thickness': middle_slice_result['thickness'], - 'curvature': middle_slice_result['curvature'], - 'midline_length': middle_slice_result['midline_length'], - 'circularity': middle_slice_result['circularity'], - 'cc_index': middle_slice_result['cc_index'], - 'total_area': middle_slice_result['total_area'], - 'total_perimeter': middle_slice_result['total_perimeter'], - 'thickness_profile': middle_slice_result['thickness_profile'] + "areas": middle_slice_result["areas"], + "areas_hofer_frahm": middle_slice_result["areas"] + if middle_slice_result["split_contours_hofer_frahm"] is not None + else middle_slice_result["areas"], + "thickness": middle_slice_result["thickness"], + "curvature": middle_slice_result["curvature"], + "midline_length": middle_slice_result["midline_length"], + "circularity": middle_slice_result["circularity"], + "cc_index": middle_slice_result["cc_index"], + "total_area": middle_slice_result["total_area"], + "total_perimeter": middle_slice_result["total_perimeter"], + "thickness_profile": middle_slice_result["thickness_profile"], } - + # multiply split contour with resolution scale factor for middle slice visualization - 
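`convert_numpy_to_json_serializable` is used above without its definition appearing in this hunk; a plausible minimal implementation (an assumption, not the patch's actual code) recursively converts numpy scalars and arrays so `json.dump` accepts the results, before the contour scaling continues below:

```python
# Assumed helper, for illustration only.
import numpy as np

def convert_numpy_to_json_serializable(obj):
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):  # np.float64, np.int64, np.bool_, ...
        return obj.item()
    if isinstance(obj, dict):
        return {k: convert_numpy_to_json_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_numpy_to_json_serializable(v) for v in obj]
    return obj
```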
split_contours = [split_contour * orig.header.get_zooms()[1] for split_contour in middle_slice_result['split_contours']] - if middle_slice_result['split_contours_hofer_frahm'] is not None: - split_contours_hofer_frahm = [split_contour * orig.header.get_zooms()[1] for split_contour in middle_slice_result['split_contours_hofer_frahm']] + split_contours = [ + split_contour * orig.header.get_zooms()[1] for split_contour in middle_slice_result["split_contours"] + ] + if middle_slice_result["split_contours_hofer_frahm"] is not None: + split_contours_hofer_frahm = [ + split_contour * orig.header.get_zooms()[1] + for split_contour in middle_slice_result["split_contours_hofer_frahm"] + ] else: split_contours_hofer_frahm = split_contours # backward compatibility - midline_equidistant = middle_slice_result['midline_equidistant'] * orig.header.get_zooms()[1] - levelpaths = [levelpath * orig.header.get_zooms()[1] for levelpath in middle_slice_result['levelpaths']] - + midline_equidistant = middle_slice_result["midline_equidistant"] * orig.header.get_zooms()[1] + levelpaths = [levelpath * orig.header.get_zooms()[1] for levelpath in middle_slice_result["levelpaths"]] + # Save middle slice visualization single_slice_result = { - 'split_contours': split_contours, - 'split_contours_hofer_frahm': split_contours_hofer_frahm, - 'midline_equidistant': midline_equidistant, - 'levelpaths': levelpaths + "split_contours": split_contours, + "split_contours_hofer_frahm": split_contours_hofer_frahm, + "midline_equidistant": midline_equidistant, + "levelpaths": levelpaths, } - IO_processes.append(create_visualization(subdivision_method, single_slice_result, midslices, - output_dir, ac_coords, pc_coords, orig.header.get_zooms()[0], ' (Middle Slice)')) - + IO_processes.append( + create_visualization( + subdivision_method, + single_slice_result, + midslices, + output_dir, + ac_coords, + pc_coords, + orig.header.get_zooms()[0], + " (Middle Slice)", + ) + ) + # get ac and pc in all spaces ac_coords_3d = np.hstack((FSAVERAGE_MIDDLE, ac_coords)) pc_coords_3d = np.hstack((FSAVERAGE_MIDDLE, pc_coords)) - standardized_to_orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig = get_mapping_to_standard_space(orig, ac_coords_3d, pc_coords_3d, orig_fsaverage_vox2vox, output_dir) - + standardized_to_orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig = ( + get_mapping_to_standard_space(orig, ac_coords_3d, pc_coords_3d, orig_fsaverage_vox2vox, output_dir) + ) # save segmentation with fitting affine orig_to_seg = np.eye(4) - orig_to_seg[0, 3] = -FSAVERAGE_MIDDLE+slices_to_analyze//2 + orig_to_seg[0, 3] = -FSAVERAGE_MIDDLE + slices_to_analyze // 2 seg_affine = fsaverage_hires_affine seg_affine = seg_affine @ np.linalg.inv(orig_to_seg) save_nifti_background(IO_processes, segmentation, seg_affine, orig.header, segmentation_path) @@ -460,11 +659,15 @@ def main(in_mri_path: str | Path, aseg_path: str | Path, output_dir: str | Path, json.dump(output_dict, f, indent=4) # save lta to fsaverage space - lta.writeLTA(upright_lta_path, orig_fsaverage_ras2ras, aseg_path, aseg_nib.header, 'fsaverage', fsaverage_header) + lta.writeLTA(upright_lta_path, orig_fsaverage_ras2ras, aseg_path, aseg_nib.header, "fsaverage", fsaverage_header) # save lta to standardized space (fsaverage + nodding + ac to center) - orig_to_standardized_ras2ras = orig.affine @ np.linalg.inv(standardized_to_orig_vox2vox) @ np.linalg.inv(orig.affine) - lta.writeLTA(orient_volume_lta_path, 
orig_to_standardized_ras2ras, in_mri_path, orig.header, in_mri_path, orig.header) + orig_to_standardized_ras2ras = ( + orig.affine @ np.linalg.inv(standardized_to_orig_vox2vox) @ np.linalg.inv(orig.affine) + ) + lta.writeLTA( + orient_volume_lta_path, orig_to_standardized_ras2ras, in_mri_path, orig.header, in_mri_path, orig.header + ) for process in IO_processes: if process is not None: @@ -474,10 +677,10 @@ def main(in_mri_path: str | Path, aseg_path: str | Path, output_dir: str | Path, if __name__ == "__main__": options = options_parse() main_args = vars(options) - + # Rename keys to match main function parameters - main_args['in_mri_path'] = main_args.pop('in_mri') - main_args['aseg_path'] = main_args.pop('aseg') - main_args['output_dir'] = main_args.pop('subject_dir', '.') - + main_args["in_mri_path"] = main_args.pop("in_mri") + main_args["aseg_path"] = main_args.pop("aseg") + main_args["output_dir"] = main_args.pop("subject_dir", ".") + main(**main_args) diff --git a/CorpusCallosum/localization/localization_inference.py b/CorpusCallosum/localization/localization_inference.py index 9119c678..9055c6d4 100644 --- a/CorpusCallosum/localization/localization_inference.py +++ b/CorpusCallosum/localization/localization_inference.py @@ -1,7 +1,5 @@ -import time import torch import numpy as np -import nibabel as nib from monai import transforms from monai.networks.nets import DenseNet as DenseNet_monai @@ -44,11 +42,6 @@ def load_model(checkpoint_path, device=None): else: state_dict = checkpoint_path - - # model = torch.nn.DataParallel(model) - # model.load_state_dict(state_dict) - # model = model.module - # torch.save(model.state_dict(), '/workspace/weights/localization_weights1.pth') model.load_state_dict(state_dict) model = model.to(device) @@ -157,58 +150,10 @@ def run_inference(model, image_volume, third_ventricle_center, device=None, tran outputs[:, 3] += t_dict['crop_top'] - return outputs[:,:2].cpu().numpy(), outputs[:,2:].cpu().numpy(), inputs.cpu().numpy(), (t_dict['crop_left'], t_dict['crop_top']) - -def load_validation_data(path): - import pandas as pd - data = pd.read_csv(path, index_col=0, header=None) - data.columns = ["image", "label", "AC_center_x", "AC_center_y", "AC_center_z", "PC_center_x", "PC_center_y", "PC_center_z"] - - data = data.drop(['15656','5bd8d9b2-e0d3-4a40-b00c-03dfffc5b206'], errors='ignore') - - ac_centers = data[["AC_center_x", "AC_center_y", "AC_center_z"]].values - pc_centers = data[["PC_center_x", "PC_center_y", "PC_center_z"]].values - images = data["image"].values - - label_widths = [] - for label_path in data['label']: - label_img =nib.load(label_path) - - if label_img.shape[0] > 100: - # check which slices have non-zero values - label = label_img.get_fdata() - non_zero_slices = np.any(label > 0, axis=(1,2)) - first_nonzero = np.argmax(non_zero_slices) - last_nonzero = len(non_zero_slices) - np.argmax(non_zero_slices[::-1]) - label_widths.append(last_nonzero - first_nonzero) - else: - label_widths.append(label_img.shape[0]) - - - extended_data = pd.read_csv("/groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/network/data/found_labels_with_meta_data_difficult_final.csv", index_col=0) - extended_data = extended_data.loc[data.index] - - third_ventricle_centers = [] - vox_sizes = [] - for aseg_up in extended_data['aseg_up_nocc']: - aseg_up_img = nib.load(aseg_up) - aseg_up_data = aseg_up_img.get_fdata() - - aseg_up_mid = aseg_up_data.shape[0] // 2 - - tv_center = np.mean(np.argwhere(aseg_up_data == 14), axis=0)[1:] - - if np.isnan(tv_center).any(): - 
import pdb; pdb.set_trace() - - third_ventricle_centers.append(tv_center) - vox_sizes.append(np.prod(aseg_up_img.header.get_zooms()[1])) - - - subj_ids = data.index.values - - return images, ac_centers, pc_centers, label_widths, third_ventricle_centers, vox_sizes, subj_ids - + return (outputs[:,:2].cpu().numpy(), + outputs[:,2:].cpu().numpy(), + inputs.cpu().numpy(), + (t_dict['crop_left'], t_dict['crop_top'])) def run_inference_on_slice(model, image_slice, center_pt, debug_output=None): @@ -235,139 +180,3 @@ def run_inference_on_slice(model, image_slice, center_pt, debug_output=None): return ac_coords, pc_coords - - - - -# TODO: add check if the prediction of first and second round diverges too much - -def run_validation(): - from matplotlib import pyplot as plt - from matplotlib.patches import Rectangle - - # Load model - #model_path = "/groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/network/experiments_localization_2/finetune_03_fixweights/checkpoints/best_metric_model.pth" - model_path = '/workspace/weights/localization_weights_acpc.pth' - - model = load_model(model_path) - - # Load a test image slice - #test_img = nib.load("/groups/ag-reuter/projects/corpus_callosum_fornix/label_QC/added_images/48e2d11f/orig_up.mgz") - - val_images, val_ac, val_pc, val_label_widths, val_third_ventricle_centers, val_vox_sizes, val_subj_ids = load_validation_data("/groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/network/data/test_joined_labels.csv") - - dist_out = [] - dist_out_dict = {} - uncertainty_out_dict = {} - for img_path, AC_center, PC_center, label_width, third_ventricle_center, vox_size, subj_id in zip(val_images, val_ac, val_pc, val_label_widths, val_third_ventricle_centers, val_vox_sizes, val_subj_ids): - - # if subj_id != '1ca3a723-d981-4bbd-ae97-3f1f03ce5f0e': - # continue - - test_img = nib.load(img_path) - test_slice = test_img.get_fdata() - - #label_width = 13 - - - - # crop to middle 3+-1 (13) slices - test_slice = test_slice[256//2-label_width//2-1:256//2+label_width//2+2] - - # Run inference - start_time = time.time() - ac_coords, pc_coords, inputs, (crop_left, crop_top) = run_inference(model, test_slice, third_ventricle_center) - center_pt = np.mean(np.concatenate([ac_coords, pc_coords], axis=0), axis=0) - ac_coords, pc_coords, inputs, (crop_left, crop_top) = run_inference(model, test_slice, center_pt) - - - inference_time = time.time() - start_time - print(f"Inference took {inference_time:.3f} seconds") - - ac_dist = np.linalg.norm(AC_center[1:] - np.mean(ac_coords, axis=0)) / vox_size - pc_dist = np.linalg.norm(PC_center[1:] - np.mean(pc_coords, axis=0)) / vox_size - # ac_dist = np.linalg.norm(AC_center[1:] - ac_coords[ac_coords.shape[0]//2]) / vox_size - # pc_dist = np.linalg.norm(PC_center[1:] - pc_coords[pc_coords.shape[0]//2]) / vox_size - dist_out.append([ac_dist, pc_dist]) - dist_out_dict[subj_id] = [ac_dist, pc_dist] - - print(f"Distance AC: {ac_dist:.4f}, PC: {pc_dist:.4f}") - - - # fig, ax = plt.subplots(1, 1, figsize=(10, 8)) - # # Original image views - # #ax.imshow(inputs[inputs.shape[0]//2, 1], cmap='gray') - # ax.imshow(test_slice[test_slice.shape[0]//2, :, :], cmap='gray') - # # Plot points on all views - # pc_coords_plot = np.mean(pc_coords, axis=0) - # ac_coords_plot = np.mean(ac_coords, axis=0) - # ax.scatter(PC_center[2], PC_center[1], c='g', marker='o', label='Pred PC', s=2, alpha=0.5) - # ax.scatter(AC_center[2], AC_center[1], c='y', marker='o', label='Pred AC', s=2, alpha=0.5) - # ax.scatter(pc_coords_plot[1], pc_coords_plot[0], c='r', 
marker='x', label='PC', s=2, alpha=0.5) - # ax.scatter(ac_coords_plot[1], ac_coords_plot[0], c='b', marker='x', label='AC', s=2, alpha=0.5) - - # for i in range(len(pc_coords)): - # ax.scatter(pc_coords[i][1], pc_coords[i][0], c='orange', marker='x', label='PC', s=2, alpha=0.5) - # ax.scatter(ac_coords[i][1], ac_coords[i][0], c='purple', marker='x', label='AC', s=2, alpha=0.5) - - # # make a box where the crop is - # ax.add_patch(Rectangle((crop_top, crop_left), 64, 64, fill=False, color='r', linewidth=2)) - # plt.savefig(f"/workspace/outputs/slice.png", bbox_inches='tight', dpi=500) - # plt.close() - - # print(np.linalg.norm(PC_center[1:] - pc_coords, axis=1)) - # print(np.linalg.norm(AC_center[1:] - ac_coords, axis=1)) - - # fig, ax = plt.subplots(1, 1, figsize=(10, 8)) - # plt.plot(np.linalg.norm(PC_center[1:] - pc_coords, axis=1), color='r') - # plt.plot(np.linalg.norm(AC_center[1:] - ac_coords, axis=1), color='b') - # plt.hlines([np.linalg.norm(PC_center[1:] - pc_coords[pc_coords.shape[0]//2])], 0, len(np.linalg.norm(PC_center[1:] - pc_coords, axis=1)), color='r', linestyle='--') - # plt.hlines([np.linalg.norm(AC_center[1:] - ac_coords[ac_coords.shape[0]//2])], 0, len(np.linalg.norm(AC_center[1:] - ac_coords, axis=1)), color='b', linestyle='--') - # plt.savefig(f"/workspace/outputs/slice_pred_dist.png", bbox_inches='tight') - # plt.close() - - - # print('Uncertainty PC: ', np.linalg.norm(pc_coords - pc_coords[pc_coords.shape[0]//2])) - # print('Uncertainty AC: ', np.linalg.norm(ac_coords - ac_coords[ac_coords.shape[0]//2])) - # uncertainty_out_dict[subj_id] = [np.linalg.norm(pc_coords - pc_coords[pc_coords.shape[0]//2]), np.linalg.norm(ac_coords - ac_coords[ac_coords.shape[0]//2])] - - - - #import pdb; pdb.set_trace() - - - # if len(dist_out_dict) == 3: - # break - - - - import pandas as pd - dist_out_df = pd.DataFrame.from_dict(dist_out_dict, orient='index', columns=['ac_dist', 'pc_dist']) - dist_out_df.to_csv("/workspace/outputs/dist_out_dict.csv") - - uncertainty_out_df = pd.DataFrame.from_dict(uncertainty_out_dict, orient='index', columns=['pc_uncertainty', 'ac_uncertainty']) - uncertainty_out_df.to_csv("/workspace/outputs/uncertainty_localization_out_dict.csv") - - - # Convert numpy array to NIfTI image before saving - #nifti_img_in = nib.Nifti1Image(inputs, affine=test_img.affine, header=test_img.header) - #nifti_orig_slice = nib.Nifti1Image(test_slice[4:-4], affine=test_img.affine, header=test_img.header) - #nib.save(nifti_img_in, "/workspace/outputs/segmentation_input.nii.gz") - #nib.save(nifti_orig_slice, "/workspace/outputs/segmentation_orig.nii.gz") - - - - print(f'Overall error - AC: {np.mean(dist_out, axis=0)[0]:.4f} mm, PC: {np.mean(dist_out, axis=0)[1]:.4f} mm') - - - # validation set, middle 2x AC: 0.7648 mm, PC: 0.8181 mm - # validation set, mean 2x AC: 0.7638 mm, PC: 0.8404 mm --- chose mean - - # test set (mean 2x) AC: 0.9004 mm, PC: 0.9482 mm - - # diificult set (mean 2x): AC: 0.9179 mm, PC: 1.3477 mm - - -# Example usage: -if __name__ == "__main__": - run_validation() \ No newline at end of file diff --git a/CorpusCallosum/registration/mapping_helpers.py b/CorpusCallosum/registration/mapping_helpers.py index ed209d81..80312a4d 100644 --- a/CorpusCallosum/registration/mapping_helpers.py +++ b/CorpusCallosum/registration/mapping_helpers.py @@ -1,83 +1,84 @@ -from pathlib import Path import numpy as np import nibabel as nib -import matplotlib.pyplot as plt from scipy.ndimage import affine_transform def make_midplane_affine(orig_affine, slices_to_analyze=1, offset=4): 
""" Creates an affine transformation matrix for midplane slices. - + Args: orig_affine: Original image affine matrix slices_to_analyze: Number of slices to analyze around midplane (default=1) offset: Additional offset in x direction (default=4) - + Returns: seg_affine: Affine matrix for midplane slices """ # Create translation matrix to center on midplane orig_to_seg = np.eye(4) - orig_to_seg[0, 3] = -256//2 + slices_to_analyze//2 + offset - + orig_to_seg[0, 3] = -256 // 2 + slices_to_analyze // 2 + offset + # Combine with original affine seg_affine = orig_affine @ np.linalg.inv(orig_to_seg) - + return seg_affine def correct_nodding(ac_pt, pc_pt): """ Calculates rotation matrix to correct for head nodding based on AC-PC line orientation. - + Args: ac_pt: Coordinates of the anterior commissure point pc_pt: Coordinates of the posterior commissure point - + Returns: rotation_matrix: 3x3 rotation matrix to align AC-PC line with posterior direction """ - ac_pc_vec = pc_pt - ac_pt + ac_pc_vec = pc_pt - ac_pt ac_pc_dist = np.linalg.norm(ac_pc_vec) - + posterior_vector = np.array([0, -ac_pc_dist]) - + # get angle between ac_pc_vec and posterior_vector dot_product = np.dot(ac_pc_vec, posterior_vector) norms_product = np.linalg.norm(ac_pc_vec) * np.linalg.norm(posterior_vector) theta = np.arccos(dot_product / norms_product) - + # Determine the sign of the angle using cross product cross_product = np.cross(ac_pc_vec, posterior_vector) if cross_product < 0: theta = -theta - + # create rotation matrix for theta - rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], - [np.sin(theta), np.cos(theta), 0], - [0, 0, 1]]) - + rotation_matrix = np.array( + [ + [np.cos(theta), -np.sin(theta), 0], + [np.sin(theta), np.cos(theta), 0], + [0, 0, 1], + ] + ) + # plot vector ac_pc_vec and posterior_vector # fig, ax = plt.subplots() # ax.quiver(0, 0, ac_pc_vec[0], ac_pc_vec[1], color='red', label='ac_pc_vec') # ax.quiver(0, 0, posterior_vector[0], posterior_vector[1], color='blue', label='posterior_vector') # ax.legend() # plt.show() - - return rotation_matrix + return rotation_matrix def apply_transform_to_pt(pts, T, inv=False): """ Applies an homoegenous 4x4 transformation matrix to a point. - + Args: pts: Point coordinates to transform T: Transformation matrix inv: If True, applies inverse of transformation (default=False) - + Returns: Transformed point coordinates """ @@ -86,21 +87,24 @@ def apply_transform_to_pt(pts, T, inv=False): T = np.linalg.inv(T) if pts.ndim == 1: - return (T @ np.hstack((pts,1)))[:3] + return (T @ np.hstack((pts, 1)))[:3] else: - return (T @ np.concatenate([pts,np.ones((1,pts.shape[1]))]))[:3] + return (T @ np.concatenate([pts, np.ones((1, pts.shape[1]))]))[:3] + -def get_mapping_to_standard_space(orig, ac_coords_3d, pc_coords_3d, orig_fsaverage_vox2vox, output_dir): +def get_mapping_to_standard_space( + orig, ac_coords_3d, pc_coords_3d, orig_fsaverage_vox2vox, output_dir +): """ Maps an image to standard space using AC-PC alignment. 
- + Args: orig: Original image ac_coords_3d: 3D coordinates of anterior commissure pc_coords_3d: 3D coordinates of posterior commissure orig_fsaverage_vox2vox: Original to fsaverage space transformation matrix output_dir: Directory for output files - + Returns: tuple: (transformation matrix, AC coords standardized, PC coords standardized, AC coords original, PC coords original) @@ -112,49 +116,79 @@ def get_mapping_to_standard_space(orig, ac_coords_3d, pc_coords_3d, orig_fsavera # convert 2D nodding correction to 3D transformation matrix nod_correct_3d = np.eye(4) - nod_correct_3d[1:3,1:3] = nod_correct_2d[:2,:2] # Copy rotation part to y,z axes - nod_correct_3d[1:3,3] = nod_correct_2d[:2,2] # Copy translation part to y,z axes (usually no translation) - - - ac_coords_after_nodding = apply_transform_to_pt(ac_coords_3d, nod_correct_3d, inv=False) - pc_coords_after_nodding = apply_transform_to_pt(pc_coords_3d, nod_correct_3d, inv=False) + nod_correct_3d[1:3, 1:3] = nod_correct_2d[:2, :2] # Copy rotation part to y,z axes + nod_correct_3d[1:3, 3] = nod_correct_2d[ + :2, 2 + ] # Copy translation part to y,z axes (usually no translation) + + ac_coords_after_nodding = apply_transform_to_pt( + ac_coords_3d, nod_correct_3d, inv=False + ) + pc_coords_after_nodding = apply_transform_to_pt( + pc_coords_3d, nod_correct_3d, inv=False + ) ac_to_center_translation = np.eye(4) - ac_to_center_translation[0,3] = image_center[0] - ac_coords_after_nodding[0] - ac_to_center_translation[1,3] = image_center[1] - ac_coords_after_nodding[1] - ac_to_center_translation[2,3] = image_center[2] - ac_coords_after_nodding[2] - + ac_to_center_translation[0, 3] = image_center[0] - ac_coords_after_nodding[0] + ac_to_center_translation[1, 3] = image_center[1] - ac_coords_after_nodding[1] + ac_to_center_translation[2, 3] = image_center[2] - ac_coords_after_nodding[2] + # correct nodding - ac_coords_standardized = apply_transform_to_pt(ac_coords_after_nodding, ac_to_center_translation, inv=False) - pc_coords_standardized = apply_transform_to_pt(pc_coords_after_nodding, ac_to_center_translation, inv=False) + ac_coords_standardized = apply_transform_to_pt( + ac_coords_after_nodding, ac_to_center_translation, inv=False + ) + pc_coords_standardized = apply_transform_to_pt( + pc_coords_after_nodding, ac_to_center_translation, inv=False + ) + + standardized_to_orig_vox2vox = ( + np.linalg.inv(orig_fsaverage_vox2vox) + @ np.linalg.inv(nod_correct_3d) + @ np.linalg.inv(ac_to_center_translation) + ) - standardized_to_orig_vox2vox = np.linalg.inv(orig_fsaverage_vox2vox) @ np.linalg.inv(nod_correct_3d) @ np.linalg.inv(ac_to_center_translation) - # calculate ac & pc in space of mri input image - ac_coords_orig = apply_transform_to_pt(ac_coords_standardized, standardized_to_orig_vox2vox, inv=False) - pc_coords_orig = apply_transform_to_pt(pc_coords_standardized, standardized_to_orig_vox2vox, inv=False) - - return standardized_to_orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig - - -def apply_transform_and_map_volume(volume, transform, affine, header, output_path=None, order=3, output_size=None): + ac_coords_orig = apply_transform_to_pt( + ac_coords_standardized, standardized_to_orig_vox2vox, inv=False + ) + pc_coords_orig = apply_transform_to_pt( + pc_coords_standardized, standardized_to_orig_vox2vox, inv=False + ) + + return ( + standardized_to_orig_vox2vox, + ac_coords_standardized, + pc_coords_standardized, + ac_coords_orig, + pc_coords_orig, + ) + + +def apply_transform_and_map_volume( + 
volume, transform, affine, header, output_path=None, order=3, output_size=None +): """ Applies transformation to a volume and saves the result. - + Args: volume: Input volume data transform: Transformation matrix to apply affine: Affine matrix for the output image header: Header for the output image output_path: Path to save transformed volume - + Returns: transformed: Transformed volume data """ if output_size is None: output_size = np.array(volume.shape) - transformed = affine_transform(volume.astype(np.float32), np.linalg.inv(transform), output_shape=output_size, order=order) + transformed = affine_transform( + volume.astype(np.float32), + np.linalg.inv(transform), + output_shape=output_size, + order=order, + ) if output_path is not None: nib.save(nib.MGHImage(transformed, affine, header), output_path) return transformed @@ -163,44 +197,48 @@ def apply_transform_and_map_volume(volume, transform, affine, header, output_pat def make_affine(simpleITKImage): """ Creates an affine transformation matrix from a SimpleITK image. - + Args: simpleITKImage: Input SimpleITK image - + Returns: affine: 4x4 affine transformation matrix in RAS coordinates """ # get affine transform in LPS - c = [simpleITKImage.TransformContinuousIndexToPhysicalPoint(p) - for p in ((1, 0, 0), - (0, 1, 0), - (0, 0, 1), - (0, 0, 0))] + c = [ + simpleITKImage.TransformContinuousIndexToPhysicalPoint(p) + for p in ((1, 0, 0), (0, 1, 0), (0, 0, 1), (0, 0, 0)) + ] c = np.array(c) - affine = np.concatenate([ - np.concatenate([c[0:3] - c[3:], c[3:]], axis=0), - [[0.], [0.], [0.], [1.]] - ], axis=1) + affine = np.concatenate( + [np.concatenate([c[0:3] - c[3:], c[3:]], axis=0), [[0.0], [0.0], [0.0], [1.0]]], + axis=1, + ) affine = np.transpose(affine) # convert to RAS to match nibabel - affine = np.matmul(np.diag([-1., -1., 1., 1.]), affine) + affine = np.matmul(np.diag([-1.0, -1.0, 1.0, 1.0]), affine) return affine - - -def map_softlabels_to_orig(outputs_soft, orig_fsaverage_vox2vox, orig, slices_to_analyze, orig_space_segmentation_path = None, fsaverage_middle=128): +def map_softlabels_to_orig( + outputs_soft, + orig_fsaverage_vox2vox, + orig, + slices_to_analyze, + orig_space_segmentation_path=None, + fsaverage_middle=128, +): """ Maps soft labels back to original image space and applies post-processing. 
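The label recovery at the end of this function (below) normalises the resampled soft predictions with a softmax over the class axis, takes the argmax, and maps the class indices to the FreeSurfer label values used throughout the pipeline (192 = corpus callosum, 250 = fornix). A minimal sketch of that step on toy data (not part of the patch):

```python
# Illustration only: softmax -> argmax -> relabel, as done below.
import numpy as np

logits = np.random.rand(4, 4, 4, 3)  # (x, y, z, class): background, CC, fornix
soft = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)

seg = np.argmax(soft, axis=-1)
seg = np.where(seg == 1, 192, seg)  # CC
seg = np.where(seg == 2, 250, seg)  # fornix
assert set(np.unique(seg)).issubset({0, 192, 250})
```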
# TODO: this could be done by padding after the transform
-
+
     Args:
         outputs_soft: Soft label predictions
         orig_fsaverage_vox2vox: Original to fsaverage space transformation
         orig: Original image
         slices_to_analyze: Number of slices to analyze
-    
+
     Returns:
         segmentation_orig_space: Final segmentation in original image space
     """
@@ -210,76 +248,92 @@ def map_softlabels_to_orig(outputs_soft, orig_fsaverage_vox2vox, orig, slices_to

         # pad to original image size
         outputs_soft_padded = np.zeros(orig.shape)
-        outputs_soft_padded[fsaverage_middle-slices_to_analyze//2:fsaverage_middle+slices_to_analyze//2+1] = outputs_soft[...,i]
+        outputs_soft_padded[
+            fsaverage_middle
+            - slices_to_analyze // 2 : fsaverage_middle
+            + slices_to_analyze // 2
+            + 1
+        ] = outputs_soft[..., i]

         s = affine_transform(
             outputs_soft_padded,
             orig_fsaverage_vox2vox,
             output_shape=orig.shape,
             order=1,
-            cval=1.0 if i == 0 else 0.0
+            cval=1.0 if i == 0 else 0.0,
         )
         softlabels_transformed.append(s)

     softlabels_orig_space = np.stack(softlabels_transformed, axis=-1)

-    # nib.save(nib.MGHImage(outputs_soft, seg_affine, transformed_img.header), Path(output_dir) / "softlabels_seg_space.mgz")
-    # nib.save(nib.MGHImage(softlabels_orig_space, orig.affine, orig.header), Path(output_dir) / "softlabels_orig_space.mgz")
-
     # apply softmax to softlabels_orig_space
-    softlabels_orig_space = np.exp(softlabels_orig_space) / np.sum(np.exp(softlabels_orig_space), axis=-1, keepdims=True)
+    softlabels_orig_space = np.exp(softlabels_orig_space) / np.sum(
+        np.exp(softlabels_orig_space), axis=-1, keepdims=True
+    )

     segmentation_orig_space = np.argmax(softlabels_orig_space, axis=-1)

-    segmentation_orig_space = np.where(segmentation_orig_space == 1, 192, segmentation_orig_space)
-    segmentation_orig_space = np.where(segmentation_orig_space == 2, 250, segmentation_orig_space)
+    segmentation_orig_space = np.where(
+        segmentation_orig_space == 1, 192, segmentation_orig_space
+    )
+    segmentation_orig_space = np.where(
+        segmentation_orig_space == 2, 250, segmentation_orig_space
+    )

     if orig_space_segmentation_path is not None:
-        nib.save(nib.MGHImage(segmentation_orig_space, orig.affine, orig.header), orig_space_segmentation_path)
+        nib.save(
+            nib.MGHImage(segmentation_orig_space, orig.affine, orig.header),
+            orig_space_segmentation_path,
+        )

     return segmentation_orig_space

+
 def interpolate_midplane(orig, orig_fsaverage_vox2vox, slices_to_analyze):
     """
     Interpolates image data at the midplane using a grid of points.
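`interpolate_midplane` (below) samples the original volume along a transformed homogeneous grid with `scipy.ndimage.map_coordinates`; a minimal sketch of that mechanism on a toy volume (not part of the patch):

```python
# Illustration only: build a voxel grid, move it with a vox2vox transform,
# and interpolate the source volume at the transformed positions.
import numpy as np
from scipy.ndimage import map_coordinates

volume = np.random.rand(16, 16, 16)
vox2vox = np.eye(4)
vox2vox[:3, 3] = 0.5  # half-voxel shift, purely for illustration

x, y, z = np.meshgrid(np.arange(4), np.arange(16), np.arange(16), indexing="ij")
grid = np.stack([x.ravel(), y.ravel(), z.ravel(), np.ones(x.size)])

grid_src = np.linalg.inv(vox2vox) @ grid  # grid positions in source voxel space
slices = map_coordinates(
    volume, grid_src[:3], order=2, mode="constant", cval=0, prefilter=True
).reshape(4, 16, 16)
assert slices.shape == (4, 16, 16)
```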
-
+
     Args:
         orig: Original image
         orig_fsaverage_vox2vox: Original to fsaverage space transformation
         slices_to_analyze: Number of slices to analyze
-    
+
     Returns:
         transformed: Interpolated image data at midplane
     """
-    #slice_thickness = 9+slices_to_analyze-1
-    # make grid of 9 slices in the fsaverage middle (cube from 123.5,0.5,0.5 to 132.5,255.5,255.5 (incudling end points, 1mm spacing))
-    x_coords = np.linspace(124-slices_to_analyze//2, 132+slices_to_analyze//2, 9+(slices_to_analyze-1), endpoint=True) # 9 points from 123.5 to 132.5
-    #x_coords = np.linspace(orig.shape[0]//2-slice_thickness//2, orig.shape[0]//2+slice_thickness//2, slice_thickness, endpoint=True)
-    y_coords = np.linspace(0, orig.shape[1]-1, orig.shape[1], endpoint=True) # 255 points from 0.5 to 255.5
-    z_coords = np.linspace(0, orig.shape[2]-1, orig.shape[2], endpoint=True) # 255 points from 0.5 to 255.5
-    X, Y, Z = np.meshgrid(x_coords, y_coords, z_coords, indexing='ij')
-    
+    # slice_thickness = 9+slices_to_analyze-1
+    # make grid of 9 slices in the fsaverage middle
+    # (cube from 123.5,0.5,0.5 to 132.5,255.5,255.5 (including end points, 1mm spacing))
+    x_coords = np.linspace(
+        124 - slices_to_analyze // 2,
+        132 + slices_to_analyze // 2,
+        9 + (slices_to_analyze - 1),
+        endpoint=True,
+    )  # 9 points from 123.5 to 132.5
+    y_coords = np.linspace(
+        0, orig.shape[1] - 1, orig.shape[1], endpoint=True
+    )  # 255 points from 0.5 to 255.5
+    z_coords = np.linspace(
+        0, orig.shape[2] - 1, orig.shape[2], endpoint=True
+    )  # 255 points from 0.5 to 255.5
+    X, Y, Z = np.meshgrid(x_coords, y_coords, z_coords, indexing="ij")
+
     # Stack coordinates and add homogeneous coordinate
-    grid_fsaverage = np.stack([
-        X.ravel(),
-        Y.ravel(),
-        Z.ravel(),
-        np.ones(X.size)
-    ])
-    
+    grid_fsaverage = np.stack([X.ravel(), Y.ravel(), Z.ravel(), np.ones(X.size)])
+
     # move grid to orig space by applying transform
     grid_orig = np.linalg.inv(orig_fsaverage_vox2vox) @ grid_fsaverage
-    
+
     # interpolate grid on orig image
     from scipy.ndimage import map_coordinates
+
     transformed = map_coordinates(
         orig.get_fdata(),
-        grid_orig[0:3,:],  # use only x,y,z coordinates (drop homogeneous coordinate)
+        grid_orig[0:3, :],  # use only x,y,z coordinates (drop homogeneous coordinate)
         order=2,
-        mode='constant',
+        mode="constant",
         cval=0,
-        prefilter=True
+        prefilter=True,
     ).reshape(len(x_coords), len(y_coords), len(z_coords))

     return transformed
-
diff --git a/CorpusCallosum/segmentation/segmentation_inference.py b/CorpusCallosum/segmentation/segmentation_inference.py
index 21316d12..22113569 100644
--- a/CorpusCallosum/segmentation/segmentation_inference.py
+++ b/CorpusCallosum/segmentation/segmentation_inference.py
@@ -1,13 +1,11 @@
-import time
 import torch
 import numpy as np
 import nibabel as nib

 from monai import transforms
-from monai.metrics import DiceMetric, HausdorffDistanceMetric

 from FastSurferCNN.models.networks import FastSurferVINN
-from transforms.segmentation_transforms import CropAroundACPC, UncropAroundACPC
+from transforms.segmentation_transforms import CropAroundACPC


 def load_model(checkpoint_path, device=None):
@@ -74,7 +72,10 @@ def run_inference(model, image_slice, AC_center, PC_center, voxel_size, device=N
     #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     device = next(model.parameters()).device

-    crop_around_acpc = lambda img, ac, pc, vox_size: CropAroundACPC(keys=['image'], padding_mm=35, random_translate=0)({'image': img, 'AC_center': ac, 'PC_center': pc, 'res': vox_size})
+    def crop_around_acpc(img, ac, pc, vox_size):
+        return
CropAroundACPC(keys=['image'], padding_mm=35, random_translate=0)( + {'image': img, 'AC_center': ac, 'PC_center': pc, 'res': vox_size} + ) # Preprocess slice inputs = torch.from_numpy(image_slice[:,None,:256,:256]) # artifact from training script @@ -113,9 +114,12 @@ def run_inference(model, image_slice, AC_center, PC_center, voxel_size, device=N pad_left, pad_right, pad_top, pad_bottom = to_pad # Pad back to original size - outputs = np.pad(outputs, ((0,0), (0,0), (pad_left.item(), pad_right.item()), (pad_top.item(), pad_bottom.item())), mode='constant', constant_values=0) - outputs_avg = np.pad(outputs_avg, ((0,0), (0,0), (pad_left.item(), pad_right.item()), (pad_top.item(), pad_bottom.item())), mode='constant', constant_values=0) - outputs_soft = np.pad(outputs_soft, ((0,0), (0,0), (pad_left.item(), pad_right.item()), (pad_top.item(), pad_bottom.item())), mode='constant', constant_values=0) + outputs = np.pad(outputs, ((0,0), (0,0), (pad_left.item(), pad_right.item()), + (pad_top.item(), pad_bottom.item())), mode='constant', constant_values=0) + outputs_avg = np.pad(outputs_avg, ((0,0), (0,0), (pad_left.item(), pad_right.item()), + (pad_top.item(), pad_bottom.item())), mode='constant', constant_values=0) + outputs_soft = np.pad(outputs_soft, ((0,0), (0,0), (pad_left.item(), pad_right.item()), + (pad_top.item(), pad_bottom.item())), mode='constant', constant_values=0) # restore original shape if orig_shape[-2:] != outputs.shape[-2:]: @@ -127,11 +131,17 @@ def run_inference(model, image_slice, AC_center, PC_center, voxel_size, device=N new_outputs_avg[:,:,:256,:256] = outputs_avg outputs_avg = new_outputs_avg - new_outputs_soft = np.zeros((outputs_soft.shape[0], outputs_soft.shape[1], orig_shape[-2], orig_shape[-1]), dtype=np.float32) + new_outputs_soft = np.zeros((outputs_soft.shape[0], outputs_soft.shape[1], + orig_shape[-2], orig_shape[-1]), dtype=np.float32) new_outputs_soft[:,:,:256,:256] = outputs_soft outputs_soft = new_outputs_soft - return outputs.transpose(0,2,3,1), inputs.cpu().numpy().transpose(0,2,3,1), outputs_avg.transpose(0,2,3,1), outputs_soft.transpose(0,2,3,1) + return ( + outputs.transpose(0,2,3,1), + inputs.cpu().numpy().transpose(0,2,3,1), + outputs_avg.transpose(0,2,3,1), + outputs_soft.transpose(0,2,3,1), + ) # TODO: load validation data and run inference on it to confirm correct processing @@ -139,7 +149,8 @@ def run_inference(model, image_slice, AC_center, PC_center, voxel_size, device=N def load_validation_data(path): import pandas as pd data = pd.read_csv(path, index_col=0, header=None) - data.columns = ["image", "label", "AC_center_x", "AC_center_y", "AC_center_z", "PC_center_x", "PC_center_y", "PC_center_z"] + data.columns = ["image", "label", "AC_center_x", "AC_center_y", "AC_center_z", + "PC_center_x", "PC_center_y", "PC_center_z"] ac_centers = data[["AC_center_x", "AC_center_y", "AC_center_z"]].values pc_centers = data[["PC_center_x", "PC_center_y", "PC_center_z"]].values @@ -166,7 +177,9 @@ def load_validation_data(path): return images, ac_centers, pc_centers, label_widths, labels, subj_ids -def one_hot_to_label(one_hot, label_ids=[0,192,250]): +def one_hot_to_label(one_hot, label_ids=None): + if label_ids is None: + label_ids = [0, 192, 250] label = np.argmax(one_hot, axis=3) if label_ids is not None: label = np.where(label == 0, label_ids[0], label) @@ -241,177 +254,3 @@ def remove_small_clusters(label_data, min_cluster_size=100): return np.stack([label_data[:,0]]+list_of_cleaned_labels, axis=1) - - -def run_validation(): - - # Load model - - model_path 
= "/groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/cc_pipeline/weights/segmentation_weights_cc_fn.pth" - # /groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/network/experiments/CCFN_softmax01/checkpoints/best_metric_model.pth - - model = load_model(model_path) - - # Load a test image slice - #test_img = nib.load("/groups/ag-reuter/projects/corpus_callosum_fornix/label_QC/added_images/48e2d11f/orig_up.mgz") - - val_images, val_ac, val_pc, val_label_widths, val_labels, val_subj_ids = load_validation_data("/groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/network/data/difficult_joined_labels.csv") - - dice_out = [] - dice_out_single_slice = [] - dice_out_dict = {} - dice_out_single_slice_dict = {} - - # Initialize Hausdorff distance metric - hd_out = [] - hd_out_single_slice = [] - hd_out_dict = {} - hd_out_single_slice_dict = {} - - for img_path, AC_center, PC_center, label_width, label_path, subj_id in zip(val_images, val_ac, val_pc, val_label_widths, val_labels, val_subj_ids): - - # if subj_id != "abf05659": - # continue - - - label_width = 5 - - test_img = nib.load(img_path) - test_slice = test_img.get_fdata() - - # crop to middle 9+5-1 (13) slices - test_slice = test_slice[256//2-label_width//2-4:256//2+label_width//2+5] - - - # Run inference - start_time = time.time() - results, inputs, outputs_avg, outputs_soft = run_inference(model, test_slice, AC_center, PC_center, voxel_size=test_img.header.get_zooms()[0]) - inference_time = time.time() - start_time - print(f"Inference took {inference_time:.3f} seconds") - - label_img = nib.load(label_path) - label = label_img.get_fdata() - - # calculate dice score - dice_metric = DiceMetric(include_background=False, reduction="mean") - hd_metric = HausdorffDistanceMetric(include_background=False, percentile=95.0, reduction="mean") - - # Convert label to one-hot format - label_tensor = torch.from_numpy(label) - - if label_tensor.shape[0] > 100: - # select non-zero slices - label_tensor = label_tensor[label_tensor.any(axis=(1,2))] - - # crop to label width - label_tensor = label_tensor[label_tensor.shape[0]//2-label_width//2:label_tensor.shape[0]//2+label_width//2+1] - - # map to 0,1,2 - ids = np.unique(label) - label_tensor = torch.where(label_tensor == ids[0], 0, label_tensor) - label_tensor = torch.where(label_tensor == ids[1], 1, label_tensor) - label_tensor = torch.where(label_tensor == ids[2], 2, label_tensor) - - label_onehot = torch.nn.functional.one_hot(label_tensor.long(), num_classes=3) # Convert to one-hot with 3 classes - label_onehot = label_onehot.permute(0, 3, 1, 2) # Move class dimension to second position (B,C, H, W) - #label_onehot = label_onehot[:,:,:256,:256] - - # Reshape results to (B, C, H, W) - results_tensor = torch.from_numpy(results) - results_tensor = results_tensor.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W) - - # Remove small clusters - results_tensor = remove_small_clusters(results_tensor.numpy(), min_cluster_size=100) - results_tensor = torch.from_numpy(results_tensor) - - - - # Calculate Dice score - dice_score = dice_metric(results_tensor, label_onehot) - midslice = results_tensor.shape[0]//2 - dice_single_slice = dice_metric(results_tensor[None,midslice], label_onehot[None,midslice]) - - # Calculate Hausdorff distance - # Get physical spacing from the image header for accurate distance calculation - spacing = test_img.header.get_zooms()[:3] # Get voxel dimensions in mm - if len(spacing) == 3: - # Use only in-plane spacing for 2D slices - spacing_tensor = torch.tensor([spacing[1], 
spacing[2]], dtype=torch.float32) - else: - spacing_tensor = torch.tensor(spacing, dtype=torch.float32) - - hd_score = hd_metric(results_tensor, label_onehot, spacing=spacing_tensor.numpy().tolist()) - hd_single_slice = hd_metric(results_tensor[None,midslice], label_onehot[None,midslice], spacing=spacing_tensor.numpy().tolist()) - - # Store results - dice_out.append(dice_score.mean(axis=0).numpy().tolist()) - dice_out_single_slice.append(dice_single_slice.numpy().tolist()) - dice_out_dict[subj_id] = dice_score.mean(axis=0).numpy().tolist() - dice_out_single_slice_dict[subj_id] = dice_single_slice.numpy()[0].tolist() - - hd_out.append(hd_score.mean(axis=0).numpy().tolist()) - hd_out_single_slice.append(hd_single_slice.numpy().tolist()) - hd_out_dict[subj_id] = hd_score.mean(axis=0).numpy().tolist() - hd_out_single_slice_dict[subj_id] = hd_single_slice.numpy()[0].tolist() - - print(f"Subject: {subj_id}") - print(f"Dice mean: {[f'{x:.3f}' for x in dice_score.mean(axis=0).numpy().tolist()]}") - print(f"HD95 mean: {[f'{x:.3f}' for x in hd_score.mean(axis=0).numpy().tolist()]} mm") - - - - - # Convert numpy array to NIfTI image before saving - nifti_img_out = nib.Nifti1Image(results, affine=test_img.affine, header=test_img.header) - nifti_img_in = nib.Nifti1Image(inputs, affine=test_img.affine, header=test_img.header) - nifti_orig_slice = nib.Nifti1Image(test_slice[4:-4], affine=test_img.affine, header=test_img.header) - nifti_avg_slice = nib.Nifti1Image(outputs_avg, affine=test_img.affine, header=test_img.header) - nifti_label = nib.Nifti1Image(label, affine=test_img.affine, header=test_img.header) - nifti_final_out = nib.Nifti1Image(one_hot_to_label(results), affine=test_img.affine, header=test_img.header) - nib.save(nifti_img_in, "/workspace/outputs/segmentation_input.nii.gz") - nib.save(nifti_img_out, "/workspace/outputs/segmentation.nii.gz") - nib.save(nifti_orig_slice, "/workspace/outputs/segmentation_orig.nii.gz") - nib.save(nifti_avg_slice, "/workspace/outputs/segmentation_avg.nii.gz") - nib.save(nifti_label, "/workspace/outputs/segmentation_label.nii.gz") - nib.save(nifti_final_out, "/workspace/outputs/segmentation_final.nii.gz") - import shutil - shutil.copy(img_path, "/workspace/outputs/segmentation_orig.mgz") - shutil.copy(label_path, "/workspace/outputs/segmentation_label.mgz") - - - - - - - print(f'Overall Validation Dice: {[f"{x:.3f}" for x in np.mean(dice_out, axis=0).tolist()]}') - print(f'Overall Validation HD95: {[f"{x:.3f}" for x in np.mean(hd_out, axis=0).tolist()]} mm') - - import pandas as pd - # Save Dice scores - dice_out_df = pd.DataFrame.from_dict(dice_out_dict, orient='index', columns=["CC", "FN"]) - dice_single_slice_df = pd.DataFrame.from_dict(dice_out_single_slice_dict, orient='index', columns=["CC", "FN"]) - dice_out_df.to_csv("/workspace/outputs/dice_out.csv") - dice_single_slice_df.to_csv("/workspace/outputs/dice_single_slice.csv") - - # Save Hausdorff distances - hd_out_df = pd.DataFrame.from_dict(hd_out_dict, orient='index', columns=["CC", "FN"]) - hd_single_slice_df = pd.DataFrame.from_dict(hd_out_single_slice_dict, orient='index', columns=["CC", "FN"]) - hd_out_df.to_csv("/workspace/outputs/hd_out.csv") - hd_single_slice_df.to_csv("/workspace/outputs/hd_single_slice.csv") - - # Create a combined metrics dataframe - combined_metrics = pd.DataFrame() - combined_metrics['Dice_CC'] = dice_out_df['CC'] - combined_metrics['Dice_FN'] = dice_out_df['FN'] - combined_metrics['HD95_CC'] = hd_out_df['CC'] - combined_metrics['HD95_FN'] = hd_out_df['FN'] - 
combined_metrics.to_csv("/workspace/outputs/combined_metrics.csv")
-
-    # Testset: Overall Dice: ['0.957', '0.829'] HD95: ['1.018', '2.799']
-    # Testset only 5 slices: Overall Validation Dice: ['0.957', '0.831'] HD95: ['1.025', '2.318']
-    # Difficultset: Overall Validation Dice: ['0.944', '0.785'] HD95: ['1.189', '4.080']
-    # Difficultset only 5 slices: Overall Validation Dice: ['0.946', '0.784'] HD95: ['1.155', '4.101']
-
-
-if __name__ == "__main__":
-    run_validation()
\ No newline at end of file
diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py
index ea47b9df..5ffcfde5 100644
--- a/CorpusCallosum/segmentation/segmentation_postprocessing.py
+++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py
@@ -2,7 +2,7 @@
 from scipy import ndimage
 from skimage.measure import label

-from CorpusCallosum.data.constants import *
+from CorpusCallosum.data.constants import CC_LABEL, FORNIX_LABEL


@@ -48,9 +48,8 @@ def get_cc_volume(desired_width_mm: int, cc_mask: np.ndarray, voxel_size: tuple[
     desired_width_vox = int(np.floor(desired_width_vox) + 1)
     desired_width_vox = desired_width_vox + 1 if desired_width_vox % 2 == 0 else desired_width_vox

-    assert cc_mask.shape[0] == desired_width_vox, f"CC mask should have {desired_width_vox} voxels, but has {cc_mask.shape[0]}"
-
-
+    assert cc_mask.shape[0] == desired_width_vox, (f"CC mask should have {desired_width_vox} voxels, "
+                                                   f"but has {cc_mask.shape[0]}")

     left_partial_volume = np.sum(cc_mask[0]) * voxel_volume * fraction_of_voxel_at_edge
     right_partial_volume = np.sum(cc_mask[-1]) * voxel_volume * fraction_of_voxel_at_edge
diff --git a/CorpusCallosum/shape/cc_endpoint_heuristic.py b/CorpusCallosum/shape/cc_endpoint_heuristic.py
index 43963a9e..2186e581 100644
--- a/CorpusCallosum/shape/cc_endpoint_heuristic.py
+++ b/CorpusCallosum/shape/cc_endpoint_heuristic.py
@@ -1,13 +1,13 @@
-import nibabel as nib
 import numpy as np
 import skimage.measure
 import scipy.ndimage
-import pandas as pd
-from shape.resample_poly import iterative_resample_polygon
+import lapy
+

 def get_endpoints(cc_mask, AC_2d, PC_2d, resolution, return_coordinates=True, contour_smoothing=1.0):
     """
-    Determines endpoints of CC by finding the point in the contour closest to the anterior and posterior commisure (with some offsets)
+    Determines endpoints of CC by finding the point in the contour closest to
+    the anterior and posterior commissures (with some offsets)

     NOTE: Expects LIA orientation
     """
@@ -20,34 +20,31 @@
     dot_product = np.dot(ac_pc_vector, horizontal_vector)
     norms = np.linalg.norm(ac_pc_vector) * np.linalg.norm(horizontal_vector)
     theta = np.arccos(dot_product / norms)
-
     # Convert symbolic theta to float and convert from radians to degrees
     theta_degrees = theta * 180 / np.pi

     rotated_cc_mask = scipy.ndimage.rotate(cc_mask, -theta_degrees, order=0, reshape=False)
-
     # rotate points around center
-    origin_point = np.array([image_size[0]//2, image_size[1]//2])
-    
+    origin_point = np.array([image_size[0] // 2, image_size[1] // 2])
+
     # Create rotation matrix for -theta
-    rot_matrix = np.array([[np.cos(-theta), -np.sin(-theta)],
-                            [np.sin(-theta), np.cos(-theta)]])
-    
+    rot_matrix = np.array([[np.cos(-theta), -np.sin(-theta)], [np.sin(-theta), np.cos(-theta)]])
+
     # Translate points to origin, rotate, then translate back
     pc_centered = PC_2d - origin_point
     ac_centered = AC_2d - origin_point
-    
+
     rotated_PC_2d = (rot_matrix @
pc_centered) + origin_point rotated_AC_2d = (rot_matrix @ ac_centered) + origin_point # get contour of CC gaussian_cc_mask = scipy.ndimage.gaussian_filter(rotated_cc_mask.astype(float), sigma=contour_smoothing) - #gaussian_cc_mask = scipy.ndimage.gaussian_filter(gaussian_cc_mask, sigma=1.0) + # gaussian_cc_mask = scipy.ndimage.gaussian_filter(gaussian_cc_mask, sigma=1.0) contour = skimage.measure.find_contours(gaussian_cc_mask, level=0.5)[0].T - contour = iterative_resample_polygon(contour.T, 701).T - contour = contour[:,:-1] + contour = lapy.tria_mesh.TriaMesh.iterative_resample_polygon(contour.T, 701).T + contour = contour[:, :-1] rotated_AC_2d = np.array(rotated_AC_2d).astype(float) rotated_PC_2d = np.array(rotated_PC_2d).astype(float) @@ -59,32 +56,29 @@ def get_endpoints(cc_mask, AC_2d, PC_2d, resolution, return_coordinates=True, co rotated_AC_2d = rotated_AC_2d + np.array([0, 5 * resolution]) # find point in contour closest to AC - AC_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_AC_2d[:,None], axis=0)) - + AC_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_AC_2d[:, None], axis=0)) + # find point in contour closest to PC - PC_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_PC_2d[:,None], axis=0)) + PC_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_PC_2d[:, None], axis=0)) # rotate startpoints to original orientation # Create rotation matrix - rot_matrix = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) + rot_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) # rotate contour to original orientation contour_rotated = np.zeros_like(contour) origin_point = np.array(origin_point).astype(float) # Create rotation matrix - rot_matrix = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) - - # Translate points to origin, rotate, then translate back - contour_centered = contour - origin_point[:,None] - contour_rotated = (rot_matrix @ contour_centered) + origin_point[:,None] + rot_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + # Translate points to origin, rotate, then translate back + contour_centered = contour - origin_point[:, None] + contour_rotated = (rot_matrix @ contour_centered) + origin_point[:, None] if return_coordinates: - AC_contour_point = contour[:,AC_startpoint_idx] - PC_contour_point = contour[:,PC_startpoint_idx] + AC_contour_point = contour[:, AC_startpoint_idx] + PC_contour_point = contour[:, PC_startpoint_idx] # Translate points to origin, rotate, then translate back ac_centered = AC_contour_point - origin_point @@ -100,11 +94,10 @@ def get_endpoints(cc_mask, AC_2d, PC_2d, resolution, return_coordinates=True, co def get_endpoints_from_nib(cc_label_nib, paths_csv, subj_id, return_coordinates=True): cc_mask = cc_label_nib.get_fdata() == 192 - cc_mask = cc_mask[cc_mask.shape[0]//2] + cc_mask = cc_mask[cc_mask.shape[0] // 2] - - posterior_commisure_center = paths_csv.loc[subj_id, 'PC_center_r':'PC_center_s'].to_numpy().astype(float) - anterior_commisure_center = paths_csv.loc[subj_id, 'AC_center_r':'AC_center_s'].to_numpy().astype(float) + posterior_commisure_center = paths_csv.loc[subj_id, "PC_center_r":"PC_center_s"].to_numpy().astype(float) + anterior_commisure_center = paths_csv.loc[subj_id, "AC_center_r":"AC_center_s"].to_numpy().astype(float) # adjust LR from label coordinates to orig_up coordinates posterior_commisure_center[0] = 128 @@ -115,68 +108,7 @@ def 
get_endpoints_from_nib(cc_label_nib, paths_csv, subj_id, return_coordinates= AC_2d = anterior_commisure_center[1:] PC_2d = posterior_commisure_center[1:] - return get_endpoints(cc_mask, AC_2d, PC_2d, resolution=cc_label_nib.header.get_zooms()[1], return_coordinates=return_coordinates) - - -if __name__ == "__main__": - from tqdm import tqdm - OUTPUT_TO_RAS = True - PLOT = False - - paths_csv = pd.read_csv('/groups/ag-reuter-2/users/pollakc/corpus_callosum_fornix/pollakc/network/data/found_labels_with_meta_data_difficult_final.csv', index_col=0) - - for subj_id in tqdm(paths_csv.index): - try: - cc_label_nib = nib.load(paths_csv.loc[subj_id, 'label_merged']) - except Exception as e: - import pdb; pdb.set_trace() - print(subj_id, 'error', e) - continue - - - - # if np.sum(cc_mask) < 20: - # print(subj_id, 'skipping') - # continue - - contour, start_point_A, start_point_P = get_endpoints_from_nib(cc_label_nib, paths_csv, subj_id) - - - - # if PLOT: - # # Add visualization - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots(figsize=(10, 8)) - # ax.imshow(cc_mask, cmap='gray') - # ax.plot(contour[1], contour[0], 'b-', label='Contour') - # # Plot initial endpoint estimates - # ax.plot(start_point_A[1], start_point_A[0], 'rx', - # markersize=8) - # ax.plot(start_point_P[1], start_point_P[0], 'rx', - # markersize=8, label='Ours') - # ax.legend() - # ax.set_title(f'Subject: {subj_id}') - # # Save plot if desired - # #plt.savefig(f'./endpoint_plots/{subj_id}.png', dpi=300, bbox_inches='tight') - # plt.show() - # plt.close() - - - if OUTPUT_TO_RAS: - # use vox2ras matrix to convert to mm - vox2ras_matrix = cc_label_nib.affine - - # Add a third dimension (z) with 0 and a fourth dimension (homogeneous coordinate) with 1 - contour_homogeneous = np.vstack([contour, np.zeros(contour.shape[1]), np.ones(contour.shape[1])]) - start_point_A_homogeneous = np.hstack([start_point_A, [0, 1]]) - start_point_P_homogeneous = np.hstack([start_point_P, [0, 1]]) - - # Apply the transformation - contour = (vox2ras_matrix @ contour_homogeneous)[:3, :] - start_point_A = (vox2ras_matrix @ start_point_A_homogeneous)[:3] - start_point_P = (vox2ras_matrix @ start_point_P_homogeneous)[:3] - - - np.save(f'./contour_data/endpoints_{subj_id}.npy', np.array([start_point_A, start_point_P]), allow_pickle=False) - np.save(f'./contour_data/contours_{subj_id}.npy', np.array(contour), allow_pickle=False) - \ No newline at end of file + return get_endpoints( + cc_mask, AC_2d, PC_2d, resolution=cc_label_nib.header.get_zooms()[1], return_coordinates=return_coordinates + ) + diff --git a/CorpusCallosum/shape/cc_mesh.py b/CorpusCallosum/shape/cc_mesh.py index b575d2c7..1c5e9a26 100644 --- a/CorpusCallosum/shape/cc_mesh.py +++ b/CorpusCallosum/shape/cc_mesh.py @@ -1,5 +1,4 @@ import tempfile -from pathlib import Path import numpy as np import matplotlib @@ -9,6 +8,7 @@ import lapy import pyrr import scipy.interpolate +from scipy.ndimage import gaussian_filter1d from whippersnappy.core import snap1 from shape.cc_thickness import make_mesh_from_contour, HiddenPrints @@ -53,7 +53,13 @@ def __init__(self, num_slices): self.t = None self.original_thickness_vertices = [None] * num_slices - def add_contour(self, slice_idx: int, contour: np.ndarray, thickness_values: np.ndarray, start_end_idx: tuple[int, int] | None = None): + def add_contour( + self, + slice_idx: int, + contour: np.ndarray, + thickness_values: np.ndarray, + start_end_idx: tuple[int, int] | None = None, + ): """Add a contour and its associated thickness values for a 
specific slice. Args: @@ -69,7 +75,7 @@ def add_contour(self, slice_idx: int, contour: np.ndarray, thickness_values: np. self.original_thickness_vertices[slice_idx] = np.where(~np.isnan(thickness_values))[0] if start_end_idx is None: - self.start_end_idx[slice_idx] = (0, len(contour)//2) + self.start_end_idx[slice_idx] = (0, len(contour) // 2) else: self.start_end_idx[slice_idx] = start_end_idx @@ -83,8 +89,6 @@ def set_acpc_coords(self, ac_coords: np.ndarray, pc_coords: np.ndarray): self.ac_coords = ac_coords self.pc_coords = pc_coords - - def set_resolution(self, resolution: float): """Set the spatial resolution of the mesh. @@ -93,21 +97,24 @@ def set_resolution(self, resolution: float): """ self.resolution = resolution - def plot_mesh(self, output_path: str | None = None, - colormap: str = "red_to_yellow", - thickness_overlay: bool = True, - show_contours: bool = False, - show_grid: bool = False, - color_range: tuple[float, float] | None = None, - show_mesh_edges: bool = False, - legend: str = "", - threshold: tuple[float, float] | None = None): + def plot_mesh( + self, + output_path: str | None = None, + colormap: str = "red_to_yellow", + thickness_overlay: bool = True, + show_contours: bool = False, + show_grid: bool = False, + color_range: tuple[float, float] | None = None, + show_mesh_edges: bool = False, + legend: str = "", + threshold: tuple[float, float] | None = None, + ): """Plot the mesh using Plotly for better performance and interactivity. - + Creates an interactive 3D visualization of the mesh with optional features like thickness overlay, contour display, and grid visualization. The plot can be saved to an HTML file or displayed in a web browser. - + Args: output_path (str, optional): Path to save the plot. If None, displays the plot interactively. colormap (str, optional): Which colormap to use. Options are: @@ -128,83 +135,80 @@ def plot_mesh(self, output_path: str | None = None, Defaults to (-0.2, 0.2). 
""" assert self.v is not None and self.t is not None, "Mesh has not been created yet" - + if len(self.v) == 0: print("Warning: No vertices in mesh to plot") return - + if len(self.t) == 0: print("Warning: No faces in mesh to plot") return - + # Define available colormaps colormaps = { "red_to_blue": [ - [0.0, "rgb(255,0,0)"], # Bright red - [0.25, "rgb(255,165,0)"], # Light orange - [0.5, "rgb(150,150,150)"], # Dark grey for middle + [0.0, "rgb(255,0,0)"], # Bright red + [0.25, "rgb(255,165,0)"], # Light orange + [0.5, "rgb(150,150,150)"], # Dark grey for middle [0.75, "rgb(173,216,230)"], # Light blue - [1.0, "rgb(0,0,255)"] # Bright blue + [1.0, "rgb(0,0,255)"], # Bright blue ], "blue_to_red": [ - [0.0, "rgb(0,0,255)"], # Bright blue + [0.0, "rgb(0,0,255)"], # Bright blue [0.25, "rgb(173,216,230)"], # Light blue - [0.5, "rgb(150,150,150)"], # Dark grey for middle - [0.75, "rgb(255,165,0)"], # Light orange - [1.0, "rgb(255,0,0)"] # Bright red + [0.5, "rgb(150,150,150)"], # Dark grey for middle + [0.75, "rgb(255,165,0)"], # Light orange + [1.0, "rgb(255,0,0)"], # Bright red ], "red_to_yellow": [ - [0.0, "rgb(255,0,0)"], # Bright red - [0.33, "rgb(255,85,0)"], # Red-orange - [0.66, "rgb(255,170,0)"], # Orange - [1.0, "rgb(255,255,0)"] # Yellow + [0.0, "rgb(255,0,0)"], # Bright red + [0.33, "rgb(255,85,0)"], # Red-orange + [0.66, "rgb(255,170,0)"], # Orange + [1.0, "rgb(255,255,0)"], # Yellow ], "yellow_to_red": [ - [0.0, "rgb(255,255,0)"], # Yellow - [0.33, "rgb(255,170,0)"], # Orange - [0.66, "rgb(255,85,0)"], # Red-orange - [1.0, "rgb(255,0,0)"] # Bright red - ] + [0.0, "rgb(255,255,0)"], # Yellow + [0.33, "rgb(255,170,0)"], # Orange + [0.66, "rgb(255,85,0)"], # Red-orange + [1.0, "rgb(255,0,0)"], # Bright red + ], } - + # Select the colormap if colormap not in colormaps: print(f"Warning: Unknown colormap '{colormap}'. 
Using 'red_to_blue' instead.") colormap = "red_to_blue" - + selected_colormap = colormaps[colormap] # If threshold is provided, modify the colormap to include grey region - if threshold is not None and thickness_overlay and hasattr(self, 'mesh_vertex_colors'): + if threshold is not None and thickness_overlay and hasattr(self, "mesh_vertex_colors"): data_min = np.min(self.mesh_vertex_colors) if color_range is None else color_range[0] data_max = np.max(self.mesh_vertex_colors) if color_range is None else color_range[1] data_range = data_max - data_min - + # Calculate normalized threshold positions thresh_low = (threshold[0] - data_min) / data_range thresh_high = (threshold[1] - data_min) / data_range - + # Ensure thresholds are within [0,1] thresh_low = max(0, min(1, thresh_low)) thresh_high = max(0, min(1, thresh_high)) - + # Create new colormap with grey threshold region grey_color = "rgb(150,150,150)" # Medium grey new_colormap = [] - + # Add colors before threshold with adjusted positions if thresh_low > 0: for pos, color in selected_colormap: if pos < 1: # Only use positions less than 1 new_pos = pos * thresh_low new_colormap.append([new_pos, color]) - + # Add threshold boundaries with grey - new_colormap.extend([ - [thresh_low, grey_color], - [thresh_high, grey_color] - ]) - + new_colormap.extend([[thresh_low, grey_color], [thresh_high, grey_color]]) + # Add colors after threshold with adjusted positions if thresh_high < 1: remaining_range = 1 - thresh_high @@ -213,7 +217,7 @@ def plot_mesh(self, output_path: str | None = None, new_pos = thresh_high + pos * remaining_range if new_pos <= 1: # Ensure we don't exceed 1 new_colormap.append([new_pos, color]) - + selected_colormap = new_colormap # Calculate data ranges and center @@ -228,120 +232,129 @@ def plot_mesh(self, output_path: str | None = None, # Add the mesh as a surface mesh_args = { - 'x': self.v[:, 0], - 'y': self.v[:, 1], - 'z': self.v[:, 2], - 'i': self.t[:, 0], # First vertex of each triangle - 'j': self.t[:, 1], # Second vertex - 'k': self.t[:, 2], # Third vertex - 'hoverinfo': 'skip', - 'lighting': dict(ambient=0.9, diffuse=0.1, roughness=0.3) + "x": self.v[:, 0], + "y": self.v[:, 1], + "z": self.v[:, 2], + "i": self.t[:, 0], # First vertex of each triangle + "j": self.t[:, 1], # Second vertex + "k": self.t[:, 2], # Third vertex + "hoverinfo": "skip", + "lighting": dict(ambient=0.9, diffuse=0.1, roughness=0.3), } - if thickness_overlay and hasattr(self, 'mesh_vertex_colors'): - mesh_args.update({ - 'intensity': self.mesh_vertex_colors, # Add intensity values for colorbar - 'showscale': True, - 'colorbar': dict( - title=dict( - text=legend, - font=dict(size=35, color='white'), # Increase title font size and make white - side='right' # Place title on right side + if thickness_overlay and hasattr(self, "mesh_vertex_colors"): + mesh_args.update( + { + "intensity": self.mesh_vertex_colors, # Add intensity values for colorbar + "showscale": True, + "colorbar": dict( + title=dict( + text=legend, + font=dict(size=35, color="white"), # Increase title font size and make white + side="right", # Place title on right side + ), + len=0.55, # Make colorbar shorter + thickness=35, # Make colorbar wider + tickfont=dict(size=30, color="white"), # Increase tick font size and make white + tickformat=".1f", # Show one decimal place ), - len=0.55, # Make colorbar shorter - thickness=35, # Make colorbar wider - tickfont=dict(size=30, color='white'), # Increase tick font size and make white - tickformat='.1f', # Show one decimal place - ), - 
'opacity': 1, - 'colorscale': selected_colormap - }) - + "opacity": 1, + "colorscale": selected_colormap, + } + ) + # Set the colorbar range if color_range is not None: - mesh_args['cmin'] = color_range[0] - mesh_args['cmax'] = color_range[1] + mesh_args["cmin"] = color_range[0] + mesh_args["cmax"] = color_range[1] else: # Use data range if no explicit range is provided - mesh_args['cmin'] = np.min(self.mesh_vertex_colors) - mesh_args['cmax'] = np.max(self.mesh_vertex_colors) + mesh_args["cmin"] = np.min(self.mesh_vertex_colors) + mesh_args["cmax"] = np.max(self.mesh_vertex_colors) else: - mesh_args['color'] = 'lightsteelblue' + mesh_args["color"] = "lightsteelblue" fig.add_trace(go.Mesh3d(**mesh_args)) if show_contours: # Add contour polylines for reference num_slices = len(self.contours) - + # Calculate z coordinates for each slice - use same calculation as in create_mesh - lr_center = self.v[len(self.v)//2][2] + lr_center = self.v[len(self.v) // 2][2] z_coordinates = np.arange(num_slices) * self.resolution - (num_slices // 2) * self.resolution + lr_center - + for i in range(num_slices): if self.contours[i] is not None: # Use slice position for z coordinate z_coord = z_coordinates[i] contour = self.contours[i] - + # Create 3D points with fixed z coordinate v_i = np.hstack([contour, np.full((len(contour), 1), z_coord)]) - + # Close the contour by adding the first point at the end v_i = np.vstack([v_i, v_i[0]]) - - fig.add_trace(go.Scatter3d( - x=v_i[:, 0], - y=v_i[:, 1], - z=v_i[:, 2], - mode='lines', - line=dict(color='white', width=2), - opacity=0.5, - hoverinfo='skip', - showlegend=False - )) - if show_mesh_edges: # show the mesh edges - edge_color = 'darkgray' + + fig.add_trace( + go.Scatter3d( + x=v_i[:, 0], + y=v_i[:, 1], + z=v_i[:, 2], + mode="lines", + line=dict(color="white", width=2), + opacity=0.5, + hoverinfo="skip", + showlegend=False, + ) + ) + if show_mesh_edges: # show the mesh edges + edge_color = "darkgray" vertices_in_first_contour = len(self.contours[0]) - vertices_to_plot_first = np.concatenate([self.v[:vertices_in_first_contour], self.v[None,0]]) + vertices_to_plot_first = np.concatenate([self.v[:vertices_in_first_contour], self.v[None, 0]]) # Add mesh edges for first 900 vertices as one continuous line - fig.add_trace(go.Scatter3d( - x=vertices_to_plot_first[:,0], - y=vertices_to_plot_first[:,1], - z=vertices_to_plot_first[:,2], - mode='lines', - line=dict(color=edge_color, width=8), - opacity=1, - hoverinfo='skip', - showlegend=False - )) + fig.add_trace( + go.Scatter3d( + x=vertices_to_plot_first[:, 0], + y=vertices_to_plot_first[:, 1], + z=vertices_to_plot_first[:, 2], + mode="lines", + line=dict(color=edge_color, width=8), + opacity=1, + hoverinfo="skip", + showlegend=False, + ) + ) vertices_in_last_contour = len(self.contours[-1]) vertices_before_last_contour = np.sum([len(c) for c in self.contours[:-1]]) - vertices_to_plot_last = np.concatenate([self.v[vertices_before_last_contour:vertices_before_last_contour + vertices_in_last_contour], self.v[None,vertices_before_last_contour]]) - fig.add_trace(go.Scatter3d( - x=vertices_to_plot_last[:,0], - y=vertices_to_plot_last[:,1], - z=vertices_to_plot_last[:,2], - mode='lines', - line=dict(color=edge_color, width=8), - opacity=1, - hoverinfo='skip', - showlegend=False - )) - + vertices_to_plot_last = np.concatenate( + [ + self.v[vertices_before_last_contour : vertices_before_last_contour + vertices_in_last_contour], + self.v[None, vertices_before_last_contour], + ] + ) + fig.add_trace( + go.Scatter3d( + 
x=vertices_to_plot_last[:, 0], + y=vertices_to_plot_last[:, 1], + z=vertices_to_plot_last[:, 2], + mode="lines", + line=dict(color=edge_color, width=8), + opacity=1, + hoverinfo="skip", + showlegend=False, + ) + ) # Calculate axis ranges to maintain equal aspect ratio ranges = [] for i in range(3): - axis_range = [ - center[i] - max_range/2, - center[i] + max_range/2 - ] + axis_range = [center[i] - max_range / 2, center[i] + max_range / 2] ranges.append(axis_range) - + # Configure axes and grid visibility axis_config = dict( showgrid=show_grid, @@ -349,29 +362,26 @@ def plot_mesh(self, output_path: str | None = None, zeroline=show_grid, showbackground=show_grid, showticklabels=show_grid, - gridcolor='white', - tickfont=dict(color='white'), - title=dict(font=dict(color='white')) + gridcolor="white", + tickfont=dict(color="white"), + title=dict(font=dict(color="white")), ) - + fig.update_layout( scene=dict( - xaxis=dict(range=ranges[0], **{**axis_config, 'title': 'AP' if show_grid else ''}), - yaxis=dict(range=ranges[1], **{**axis_config, 'title': 'SI' if show_grid else ''}), - zaxis=dict(range=ranges[2], **{**axis_config, 'title': 'LR' if show_grid else ''}), - camera=dict( - eye=dict(x=1.5, y=1.5, z=1), - up=dict(x=0, y=0, z=1) - ), - aspectmode='cube', # Force equal aspect ratio + xaxis=dict(range=ranges[0], **{**axis_config, "title": "AP" if show_grid else ""}), + yaxis=dict(range=ranges[1], **{**axis_config, "title": "SI" if show_grid else ""}), + zaxis=dict(range=ranges[2], **{**axis_config, "title": "LR" if show_grid else ""}), + camera=dict(eye=dict(x=1.5, y=1.5, z=1), up=dict(x=0, y=0, z=1)), + aspectmode="cube", # Force equal aspect ratio aspectratio=dict(x=1, y=1, z=1), - bgcolor='black', - dragmode='orbit' # Enable orbital rotation by default + bgcolor="black", + dragmode="orbit", # Enable orbital rotation by default ), showlegend=False, margin=dict(l=0, r=100, t=0, b=0), # Increased right margin for colorbar - paper_bgcolor='black', - plot_bgcolor='black' + paper_bgcolor="black", + plot_bgcolor="black", ) if output_path is not None: @@ -381,11 +391,10 @@ def plot_mesh(self, output_path: str | None = None, import tempfile import webbrowser import os - - temp_path = os.path.join(tempfile.gettempdir(), 'cc_mesh_plot.html') - fig.write_html(temp_path) - webbrowser.open('file://' + temp_path) + temp_path = os.path.join(tempfile.gettempdir(), "cc_mesh_plot.html") + fig.write_html(temp_path) + webbrowser.open("file://" + temp_path) def get_contour_edge_lengths(self, contour_idx): """Get the lengths of the edges of a contour. 
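Editor's note on the threshold branch of `plot_mesh` above: the patch splices a flat grey band into the Plotly colorscale so that values inside `threshold` (normalized to the data range) render grey, while the original scale is compressed into the regions below and above the band. The following is a minimal runnable sketch of that idea; `add_grey_band` and the example scale are hypothetical names for illustration, not part of this patch, and the boundary handling is simplified relative to the code above.

```python
# Hypothetical helper sketching the grey-threshold colorscale built in plot_mesh;
# illustrative only, the patch constructs the list inline with slightly different
# edge handling.
def add_grey_band(colorscale, data_min, data_max, threshold):
    """Return a Plotly-style colorscale with a flat grey band over `threshold`."""
    span = data_max - data_min
    lo = min(1.0, max(0.0, (threshold[0] - data_min) / span))  # normalized band edges
    hi = min(1.0, max(0.0, (threshold[1] - data_min) / span))
    grey = "rgb(150,150,150)"
    # compress the original scale into [0, lo) ...
    out = [[pos * lo, color] for pos, color in colorscale if pos < 1.0] if lo > 0 else []
    out += [[lo, grey], [hi, grey]]  # everything inside the band maps to grey
    if hi < 1.0:
        # ... and into (hi, 1] after the band
        out += [[hi + pos * (1.0 - hi), color] for pos, color in colorscale if pos > 0.0]
    return out

# Example: grey out values in (-0.2, 0.2) on data spanning [-1, 1].
red_to_yellow = [[0.0, "rgb(255,0,0)"], [0.5, "rgb(255,170,0)"], [1.0, "rgb(255,255,0)"]]
print(add_grey_band(red_to_yellow, -1.0, 1.0, (-0.2, 0.2)))
```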
@@ -398,7 +407,6 @@ def get_contour_edge_lengths(self, contour_idx): """ edges = np.diff(self.contours[contour_idx], axis=0) return np.sqrt(np.sum(edges**2, axis=1)) - @staticmethod def make_triangles_between_contours(contour1, contour2): @@ -422,25 +430,23 @@ def make_triangles_between_contours(contour1, contour2): triangles = [] n1 = len(contour1) n2 = len(contour2) - + for i in range(n1): # Current and next indices for contour1 c1_curr = (start_idx_c1 + i) % n1 c1_next = (start_idx_c1 + i + 1) % n1 - + # Current and next indices for contour2, offset by n1 to account for vertex stacking c2_curr = ((start_idx_c2 + i) % n2) + n1 c2_next = ((start_idx_c2 + i + 1) % n2) + n1 - + # Create two triangles to form a quad between the contours triangles.append([c1_curr, c2_curr, c1_next]) triangles.append([c2_curr, c2_next, c1_next]) return np.array(triangles) - def _create_levelpaths(self, contour_idx, points, trias, num_points=None): - # # compute poisson with HiddenPrints(): cc_tria = lapy.TriaMesh(points, trias) @@ -448,40 +454,35 @@ def _create_levelpaths(self, contour_idx, points, trias, num_points=None): bdr = np.array(cc_tria.boundary_loops()[0]) # find index of endpoints in bdr list - iidx1=np.where(bdr==self.start_end_idx[contour_idx][0])[0][0] - iidx2=np.where(bdr==self.start_end_idx[contour_idx][1])[0][0] + iidx1 = np.where(bdr == self.start_end_idx[contour_idx][0])[0][0] + iidx2 = np.where(bdr == self.start_end_idx[contour_idx][1])[0][0] # create boundary condition (0 at endpoints, -1 on one side, 1 on the other): if iidx1 > iidx2: - tmp= iidx2 + tmp = iidx2 iidx2 = iidx1 iidx1 = tmp dcond = np.ones(bdr.shape) - dcond[iidx1] =0 - dcond[iidx2] =0 - dcond[iidx1+1:iidx2] = -1 - + dcond[iidx1] = 0 + dcond[iidx2] = 0 + dcond[iidx1 + 1 : iidx2] = -1 # Extract path with HiddenPrints(): fem = lapy.Solver(cc_tria) - vfunc = fem.poisson(0,(bdr,dcond)) + vfunc = fem.poisson(0, (bdr, dcond)) if num_points is not None: # TODO: do midline stuff level = 0 - midline_equidistant, midline_length = cc_tria.level_path(vfunc, level, n_points=num_points+2) - midline_equidistant = midline_equidistant[:,:2] + midline_equidistant, midline_length = cc_tria.level_path(vfunc, level, n_points=num_points + 2) + midline_equidistant = midline_equidistant[:, :2] eval_points = midline_equidistant else: eval_points = self.contours[contour_idx] - gf = lapy.diffgeo.compute_rotated_f(cc_tria,vfunc) - - - - + gf = lapy.diffgeo.compute_rotated_f(cc_tria, vfunc) # interpolate midline to get levels to evaluate - gf_interp = scipy.interpolate.griddata(cc_tria.v[:,0:2], gf, eval_points, method='nearest') + gf_interp = scipy.interpolate.griddata(cc_tria.v[:, 0:2], gf, eval_points, method="nearest") # sort by value sorting_idx_gf = np.argsort(gf_interp) @@ -489,25 +490,24 @@ def _create_levelpaths(self, contour_idx, points, trias, num_points=None): sorted_thickness_values = self.thickness_values[contour_idx][sorting_idx_gf] # get levels to evaluate - #level_length = tria.level_length(gf, gf_interp) + # level_length = tria.level_length(gf, gf_interp) levelpaths = [] thickness_values = [] - for i in range(0,len(eval_points)): - level = gf_interp[i] + for i in range(0, len(eval_points)): + level = gf_interp[i] # levelpath starts at index zero if level == 0: continue lvlpath, lvlpath_length, tria_idx = cc_tria.level_path(gf, level, get_tria_idx=True) - + levelpaths.append(lvlpath) thickness_values.append(sorted_thickness_values[i]) - + return levelpaths, thickness_values def _create_cap(self, points, trias, contour_idx): - levelpaths, 
thickness_values = self._create_levelpaths(contour_idx, points, trias) # Create mesh from level paths @@ -519,7 +519,8 @@ def _create_cap(self, points, trias, contour_idx): # smooth thickness values from scipy.ndimage import gaussian_filter1d - for i in range(3): + + for _ in range(3): sorted_thickness_values = gaussian_filter1d(sorted_thickness_values, sigma=5) NUM_LEVELPOINTS = 50 @@ -528,7 +529,6 @@ def _create_cap(self, points, trias, contour_idx): # TODO: handle gap between first/last levelpath and contour for idx, levelpath1 in enumerate(levelpaths): - levelpath1 = lapy.TriaMesh._TriaMesh__iterative_resample_polygon(levelpath1, NUM_LEVELPOINTS) level_vertices.append(levelpath1) level_colors.append(np.full((len(levelpath1)), sorted_thickness_values[idx])) @@ -538,33 +538,33 @@ def _create_cap(self, points, trias, contour_idx): # Create faces between the two paths by connecting vertices faces_between = [] i, j = 0, 0 - - while i < len(levelpath1)-1 and j < len(levelpath2)-1: - faces_between.append([i, i+1, len(levelpath1)+j]) - faces_between.append([i+1, len(levelpath1)+j+1, len(levelpath1)+j]) - + + while i < len(levelpath1) - 1 and j < len(levelpath2) - 1: + faces_between.append([i, i + 1, len(levelpath1) + j]) + faces_between.append([i + 1, len(levelpath1) + j + 1, len(levelpath1) + j]) + i += 1 j += 1 - - while i < len(levelpath1)-1: - faces_between.append([i, i+1, len(levelpath1)+j]) + + while i < len(levelpath1) - 1: + faces_between.append([i, i + 1, len(levelpath1) + j]) i += 1 - - while j < len(levelpath2)-1: - faces_between.append([i, len(levelpath1)+j+1, len(levelpath1)+j]) + + while j < len(levelpath2) - 1: + faces_between.append([i, len(levelpath1) + j + 1, len(levelpath1) + j]) j += 1 - + if faces_between: faces_between = np.array(faces_between) level_faces.append(faces_between + vertex_counter) - vertex_counter += len(levelpath1) + vertex_counter += len(levelpath1) # Convert to numpy arrays level_vertices = np.vstack(level_vertices) level_faces = np.vstack(level_faces) level_colors = np.concatenate(level_colors) - + return level_vertices, level_faces, level_colors def create_mesh(self, lr_center: float = 0, closed: bool = False, smooth: int = 0): @@ -586,17 +586,19 @@ def create_mesh(self, lr_center: float = 0, closed: bool = False, smooth: int = self.v = np.array([]) self.t = np.array([]) return - + # Calculate z coordinates for each slice - z_coordinates = np.arange(len(valid_contours)) * self.resolution - (len(valid_contours) // 2) * self.resolution + lr_center - + z_coordinates = ( + np.arange(len(valid_contours)) * self.resolution - (len(valid_contours) // 2) * self.resolution + lr_center + ) + # Build vertices list with z-coordinates vertices = [] faces = [] vertex_start_indices = [] # Track starting index for each contour current_index = 0 - - for i, (idx, contour) in enumerate(valid_contours): + + for i, (_, contour) in enumerate(valid_contours): vertex_start_indices.append(current_index) vertices.append(np.hstack([contour, np.full((len(contour), 1), z_coordinates[i])])) @@ -608,32 +610,32 @@ def create_mesh(self, lr_center: float = 0, closed: bool = False, smooth: int = current_index += len(contour) - - self.set_mesh(vertices, faces, self.thickness_values) if smooth > 0: self.smooth_(smooth) - if closed: # Close the mesh by creating caps on both ends # Left cap (first slice) - use counterclockwise orientation - left_side_points, left_side_trias = make_mesh_from_contour(self.v[:vertex_start_indices[1]][..., :2]) + left_side_points, left_side_trias = 
make_mesh_from_contour(self.v[: vertex_start_indices[1]][..., :2]) left_side_points = np.hstack([left_side_points, np.full((len(left_side_points), 1), z_coordinates[0])]) # Right cap (last slice) - reverse points for proper orientation - right_side_points, right_side_trias = make_mesh_from_contour(self.v[vertex_start_indices[-1]:][..., :2]) + right_side_points, right_side_trias = make_mesh_from_contour(self.v[vertex_start_indices[-1] :][..., :2]) right_side_points = np.hstack([right_side_points, np.full((len(right_side_points), 1), z_coordinates[-1])]) color_sides = True if color_sides: - left_side_points, left_side_trias, left_side_colors = self._create_cap(left_side_points, left_side_trias, 0) - right_side_points, right_side_trias, right_side_colors = self._create_cap(right_side_points, right_side_trias, len(self.contours) - 1) + left_side_points, left_side_trias, left_side_colors = self._create_cap( + left_side_points, left_side_trias, 0 + ) + right_side_points, right_side_trias, right_side_colors = self._create_cap( + right_side_points, right_side_trias, len(self.contours) - 1 + ) # reverse right side trias - right_side_trias = right_side_trias[:,::-1] - + right_side_trias = right_side_trias[:, ::-1] left_side_trias = left_side_trias + current_index current_index += len(left_side_points) @@ -641,8 +643,11 @@ def create_mesh(self, lr_center: float = 0, closed: bool = False, smooth: int = right_side_trias = right_side_trias + current_index current_index += len(right_side_points) - self.set_mesh([self.v, left_side_points, right_side_points], [self.t, left_side_trias, right_side_trias], [self.mesh_vertex_colors, left_side_colors, right_side_colors]) - + self.set_mesh( + [self.v, left_side_points, right_side_points], + [self.t, left_side_trias, right_side_trias], + [self.mesh_vertex_colors, left_side_colors, right_side_colors], + ) def fill_thickness_values(self): """ @@ -687,17 +692,15 @@ def fill_thickness_values(self): self.thickness_values[i] = thickness - def smooth_thickness_values(self, iterations: int = 1): """ Smooth the thickness values using a Gaussian filter """ - from scipy.ndimage import gaussian_filter1d + for i in range(len(self.thickness_values)): if self.thickness_values[i] is not None: self.thickness_values[i] = gaussian_filter1d(self.thickness_values[i], sigma=5) - def plot_contour(self, slice_idx: int, output_path: str): """Plot a single contour with thickness values. 
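Editor's note: `create_mesh` above relies on the quad-strip stitching of `make_triangles_between_contours`, where matching vertices on two closed contours of equal length are joined by two triangles per quad. Below is a minimal, self-contained sketch of that triangulation under simplified assumptions (no start-index alignment, both loops have `n` vertices); `stitch_contours` is a hypothetical name, not the patch's function.

```python
# Sketch of the between-contour triangulation used when stacking slice contours:
# contour A occupies vertex indices 0..n-1, contour B occupies n..2n-1.
import numpy as np

def stitch_contours(n, offset=0):
    """Triangle indices joining closed loop A (0..n-1) to closed loop B (n..2n-1)."""
    tris = []
    for i in range(n):
        a, a_next = i, (i + 1) % n              # consecutive vertices on loop A
        b, b_next = n + i, n + (i + 1) % n      # matching vertices on loop B
        tris.append([a, b, a_next])             # lower triangle of the quad
        tris.append([b, b_next, a_next])        # upper triangle of the quad
    return np.asarray(tris, dtype=int) + offset

# Two 4-point contours yield 8 triangles; `offset` shifts indices for later slices,
# mirroring how the patch accumulates a running vertex index across contours.
print(stitch_contours(4))
```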
@@ -713,71 +716,74 @@ def plot_contour(self, slice_idx: int, output_path: str): """ if self.contours[slice_idx] is None: - raise ValueError(f'Contour for slice {slice_idx} is not set') + raise ValueError(f"Contour for slice {slice_idx} is not set") contour = self.contours[slice_idx] plt.figure(figsize=(15, 10)) # Get thickness values for this slice thickness = self.thickness_values[slice_idx] - + # Plot points with colors based on thickness for i in range(len(contour)): if np.isnan(thickness[i]): - plt.plot(contour[i,0], contour[i,1], 'o', color='gray', markersize=1) + plt.plot(contour[i, 0], contour[i, 1], "o", color="gray", markersize=1) else: # Map thickness to color from red to yellow - plt.plot(contour[i,0], contour[i,1], 'o', color=plt.cm.YlOrRd(thickness[i]/np.nanmax(thickness)), markersize=1) - + plt.plot( + contour[i, 0], + contour[i, 1], + "o", + color=plt.cm.YlOrRd(thickness[i] / np.nanmax(thickness)), + markersize=1, + ) + # Connect points with lines - plt.plot(contour[:,0], contour[:,1], '-', color='black', alpha=0.3, label='Contour') - plt.axis('equal') - plt.xlabel('X') - plt.ylabel('Y') - plt.title(f'CC contour for slice {slice_idx}') + plt.plot(contour[:, 0], contour[:, 1], "-", color="black", alpha=0.3, label="Contour") + plt.axis("equal") + plt.xlabel("X") + plt.ylabel("Y") + plt.title(f"CC contour for slice {slice_idx}") plt.legend() plt.grid(True) plt.tight_layout() plt.savefig(output_path, dpi=300) - - def smooth_contour(self, contour_idx, window_size=5): """ Smooth a contour using a moving average filter - + Parameters: ----------- contour : tuple of arrays The contour coordinates (x, y) window_size : int Size of the smoothing window - + Returns: -------- tuple of arrays The smoothed contour coordinates (x, y) """ x, y = self.contours[contour_idx].T - + # Ensure the window size is odd if window_size % 2 == 0: window_size += 1 - + # Create a padded version of the arrays to handle the edges - x_padded = np.pad(x, (window_size//2, window_size//2), mode='wrap') - y_padded = np.pad(y, (window_size//2, window_size//2), mode='wrap') - + x_padded = np.pad(x, (window_size // 2, window_size // 2), mode="wrap") + y_padded = np.pad(y, (window_size // 2, window_size // 2), mode="wrap") + # Apply moving average x_smoothed = np.zeros_like(x) y_smoothed = np.zeros_like(y) - + for i in range(len(x)): - x_smoothed[i] = np.mean(x_padded[i:i+window_size]) - y_smoothed[i] = np.mean(y_padded[i:i+window_size]) - - self.contours[contour_idx] = np.array([x_smoothed, y_smoothed]).T + x_smoothed[i] = np.mean(x_padded[i : i + window_size]) + y_smoothed[i] = np.mean(y_padded[i : i + window_size]) + self.contours[contour_idx] = np.array([x_smoothed, y_smoothed]).T def plot_cc_contour_with_levelsets(self, contour_idx=0, levelpaths=None, title=None, save_path=None, colorbar=True): """Plot a contour with levelset visualization. @@ -796,9 +802,10 @@ def plot_cc_contour_with_levelsets(self, contour_idx=0, levelpaths=None, title=N matplotlib.figure.Figure: The created figure object. 
""" - plot_values = np.array(self.thickness_values[contour_idx][~np.isnan(self.thickness_values[contour_idx])][:100])[::-1] + plot_values = np.array(self.thickness_values[contour_idx][~np.isnan(self.thickness_values[contour_idx])][:100])[ + ::-1 + ] # double plot values with linear interpolation - # Create bar plot of thickness values # fig, ax = plt.subplots(figsize=(10, 4)) @@ -810,7 +817,7 @@ def plot_cc_contour_with_levelsets(self, contour_idx=0, levelpaths=None, title=N # ax.invert_xaxis() # plt.tight_layout() # plt.show() - + points, trias = make_mesh_from_contour(self.contours[contour_idx], max_volume=0.5, min_angle=25, verbose=False) # make points 3D by adding zero @@ -826,8 +833,7 @@ def plot_cc_contour_with_levelsets(self, contour_idx=0, levelpaths=None, title=N margin = 1 resolution = 0.05 # Higher resolution for smoother interpolation x_grid, y_grid = np.meshgrid( - np.arange(x_min - margin, x_max + margin, resolution), - np.arange(y_min - margin, y_max + margin, resolution) + np.arange(x_min - margin, x_max + margin, resolution), np.arange(y_min - margin, y_max + margin, resolution) ) # Create a path from the outside contour @@ -844,33 +850,29 @@ def plot_cc_contour_with_levelsets(self, contour_idx=0, levelpaths=None, title=N all_level_values = [] for i, path in enumerate(levelpaths): - if len(path) == 1: all_level_points_x.append(path[0][0]) all_level_points_y.append(path[0][1]) all_level_values.append(plot_values[i]) continue - - # make levelpath - path = lapy.TriaMesh._TriaMesh__resample_polygon(path, 1000) - + # make levelpath + path = lapy.TriaMesh._TriaMesh__resample_polygon(path, 1000) # Extend at the beginning: add point in direction opposite to first segment first_segment = path[1] - path[0] # standardize length of first segment first_segment = first_segment / np.linalg.norm(first_segment) * 10 - extension_start = path[0] - first_segment + extension_start = path[0] - first_segment all_level_points_x.append(extension_start[0]) all_level_points_y.append(extension_start[1]) all_level_values.append(plot_values[i]) - + # Add original path points for point in path: all_level_points_x.append(point[0]) all_level_points_y.append(point[1]) all_level_values.append(plot_values[i]) - # Extend at the end: add point in direction of last segment last_segment = path[-1] - path[-2] @@ -889,11 +891,7 @@ def plot_cc_contour_with_levelsets(self, contour_idx=0, levelpaths=None, title=N # Use griddata to perform smooth interpolation - using 'linear' instead of 'cubic' # and properly formatting the input points grid_values = scipy.interpolate.griddata( - (all_level_points_x, all_level_points_y), - all_level_values, - (x_grid, y_grid), - method='linear', - fill_value=0 + (all_level_points_x, all_level_points_y), all_level_values, (x_grid, y_grid), method="linear", fill_value=0 ) # smooth the grid_values @@ -902,7 +900,6 @@ def plot_cc_contour_with_levelsets(self, contour_idx=0, levelpaths=None, title=N # Apply the mask to only show values inside the contour masked_values = np.where(mask, grid_values, np.nan) - # Sample colormaps (e.g., 'binary' and 'gist_heat_r') colors1 = plt.cm.binary([0.4] * 128) colors2 = plt.cm.hot(np.linspace(0.8, 0.1, 128)) @@ -911,56 +908,67 @@ def plot_cc_contour_with_levelsets(self, contour_idx=0, levelpaths=None, title=N colors = np.vstack((colors2, colors1)) # Create a new colormap - cmap = matplotlib.colors.LinearSegmentedColormap.from_list('my_colormap', colors) + cmap = matplotlib.colors.LinearSegmentedColormap.from_list("my_colormap", colors) - - # Plot CC 
contour with levelsets - fig = plt.figure(figsize=(10,3)) + fig = plt.figure(figsize=(10, 3)) # Apply a 10-degree rotation to the entire plot base = plt.gca().transData transform = matplotlib.transforms.Affine2D().rotate_deg(10) transform = transform + base # Plot the filled contour with interpolated colors - plt.imshow(masked_values, extent=[x_min-margin, x_max+margin, y_min-margin, y_max+margin], - origin='lower', cmap=cmap, alpha=1, interpolation='bilinear', vmin=0, vmax=0.10, transform=transform) - - plt.imshow(masked_values, - extent=[x_min-margin, x_max+margin, y_min-margin, y_max+margin], - origin='lower', cmap=cmap, alpha=1, interpolation='bilinear', - vmin=0, vmax=0.10, - #norm=LogNorm(vmin=1e-3, vmax=0.1), # Set minimum to avoid log(0) - transform=transform) + plt.imshow( + masked_values, + extent=[x_min - margin, x_max + margin, y_min - margin, y_max + margin], + origin="lower", + cmap=cmap, + alpha=1, + interpolation="bilinear", + vmin=0, + vmax=0.10, + transform=transform, + ) + + plt.imshow( + masked_values, + extent=[x_min - margin, x_max + margin, y_min - margin, y_max + margin], + origin="lower", + cmap=cmap, + alpha=1, + interpolation="bilinear", + vmin=0, + vmax=0.10, + # norm=LogNorm(vmin=1e-3, vmax=0.1), # Set minimum to avoid log(0) + transform=transform, + ) if colorbar: # Add a colorbar cbar = plt.colorbar(aspect=10) cbar.ax.set_ylim(0.001, 0.054) cbar.ax.set_yticks([0.0, 0.01, 0.02, 0.03, 0.04, 0.05]) - #cbar.ax.set_yticks([0.001, 0.01, 0.05]) - #cbar.ax.set_yticklabels(['0.001', '0.01', '0.05']) - cbar.set_label('p-value (log scale)') + # cbar.ax.set_yticks([0.001, 0.01, 0.05]) + # cbar.ax.set_yticklabels(['0.001', '0.01', '0.05']) + cbar.set_label("p-value (log scale)") # Plot the outside contour on top for clear boundary - plt.plot(outside_contour[0], outside_contour[1], 'k-', linewidth=2, label='CC Contour', transform=transform) - + plt.plot(outside_contour[0], outside_contour[1], "k-", linewidth=2, label="CC Contour", transform=transform) # plot levelpaths - #for i, path in enumerate(levelpaths): + # for i, path in enumerate(levelpaths): # plt.plot(path[:,0], path[:,1], 'k--', linewidth=1, alpha=0.2, transform=transform) # plot midline # if midline_equidistant is not None: # midline_x, midline_y = zip(*midline_equidistant) # plt.plot(midline_x, midline_y, 'k--', linewidth=2, transform=transform, alpha=0.2) - - plt.axis('equal') - plt.title(title, fontsize=14, fontweight='bold') - #plt.legend(loc='best') + plt.axis("equal") + plt.title(title, fontsize=14, fontweight="bold") + # plt.legend(loc='best') plt.gca().invert_xaxis() - plt.axis('off') - #plt.tight_layout() + plt.axis("off") + # plt.tight_layout() # plt.ylim(-105, -75) # plt.xlim(181, 101) if save_path is not None: @@ -968,8 +976,6 @@ def plot_cc_contour_with_levelsets(self, contour_idx=0, levelpaths=None, title=N plt.show() return fig - - def set_mesh(self, vertices, faces, thickness_values=None): """Set the mesh vertices, faces, and optional thickness values. @@ -989,7 +995,7 @@ def set_mesh(self, vertices, faces, thickness_values=None): # Skip parent initialization since we have no faces else: super().__init__(np.vstack(vertices), np.vstack(faces)) - + if thickness_values is not None: # Filter out empty thickness arrays and concatenate valid_thickness = [tv for tv in thickness_values if tv is not None and len(tv) > 0] @@ -1003,13 +1009,11 @@ def __create_cc_viewmat(): """ Create the view matrix for a nice view of the corpus callosum. 
""" - viewLeft = np.array([[ 0, 0,-1, 0], [-1, 0, 0, 0], [ 0, 1, 0, 0], [ 0, 0, 0, 1]]) # left w top up // right + viewLeft = np.array([[0, 0, -1, 0], [-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) # left w top up // right transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) viewmat = transl * viewLeft - - - #rotate 10 degrees around x axis + # rotate 10 degrees around x axis rot = pyrr.Matrix44.from_x_rotation(np.deg2rad(-10)) viewmat = viewmat * rot @@ -1040,27 +1044,27 @@ def snap_cc_picture(self, output_path: str): if len(self.t) == 0: print("Warning: Cannot create snapshot - no faces in mesh") return - + # create temp file - temp_file = tempfile.NamedTemporaryFile(suffix='.fssurf', delete=True) + temp_file = tempfile.NamedTemporaryFile(suffix=".fssurf", delete=True) self.write_fssurf(temp_file.name) # Write thickness values as overlay - if hasattr(self, 'mesh_vertex_colors'): - overlay_file = tempfile.NamedTemporaryFile(suffix='.w', delete=True) + if hasattr(self, "mesh_vertex_colors"): + overlay_file = tempfile.NamedTemporaryFile(suffix=".w", delete=True) # Write thickness values in FreeSurfer .w format nib.freesurfer.write_morph_data(overlay_file.name, self.mesh_vertex_colors) overlaypath = overlay_file.name else: overlaypath = None - + snap1( temp_file.name, overlaypath=overlaypath, view=None, viewmat=self.__create_cc_viewmat(), - width=3*500, - height=3*300, + width=3 * 500, + height=3 * 300, outpath=output_path, ambient=0.6, colorbar_scale=0.5, @@ -1068,10 +1072,10 @@ def snap_cc_picture(self, output_path: str): colorbar_x=0.19, brain_scale=2.1, fthresh=0, - caption='Corpus Callosum thickness (mm)', + caption="Corpus Callosum thickness (mm)", caption_y=0.85, caption_x=0.17, - caption_scale=0.5 + caption_scale=0.5, ) temp_file.close() @@ -1090,7 +1094,6 @@ def smooth_(self, iterations: int = 1): super().smooth_(iterations) self.v[:, 2] = z_values - def save_contours(self, output_path: str): """Save the contours to a CSV file. @@ -1103,13 +1106,16 @@ def save_contours(self, output_path: str): Args: output_path (str): Path where to save the CSV file. 
""" - with open(output_path, 'w') as f: + with open(output_path, "w") as f: # Write header f.write("slice_idx,x,y\n") # Write data for slice_idx, contour in enumerate(self.contours): if contour is not None: # Skip empty slices - f.write(f"New contour, anterior_endpoint_idx={self.start_end_idx[slice_idx][0]},posterior_endpoint_idx={self.start_end_idx[slice_idx][1]}\n") + f.write( + f"New contour, anterior_endpoint_idx={self.start_end_idx[slice_idx][0]}, " + f"posterior_endpoint_idx={self.start_end_idx[slice_idx][1]}\n" + ) for point in contour: f.write(f"{slice_idx},{point[0]},{point[1]}\n") @@ -1129,33 +1135,33 @@ def load_contours(self, input_path: str): current_points = [] self.contours = [] self.start_end_idx = [] - - with open(input_path, 'r') as f: + + with open(input_path) as f: # Skip header next(f) - + for line in f: - if line.startswith('New contour'): + if line.startswith("New contour"): # If we have points from previous contour, save them if current_points: self.contours.append(np.array(current_points)) current_points = [] - + # Extract anterior and posterior endpoint indices # Format: "New contour, anterior_endpoint_idx=X,posterior_endpoint_idx=Y" - parts = line.strip().split(',') - anterior_idx = int(parts[1].split('=')[1]) - posterior_idx = int(parts[2].split('=')[1]) + parts = line.strip().split(",") + anterior_idx = int(parts[1].split("=")[1]) + posterior_idx = int(parts[2].split("=")[1]) self.start_end_idx.append((anterior_idx, posterior_idx)) else: # Parse point data - slice_idx, x, y = line.strip().split(',') + slice_idx, x, y = line.strip().split(",") current_points.append([float(x), float(y)]) - + # Don't forget to add the last contour if current_points: self.contours.append(np.array(current_points)) - + # Convert lists to fixed-size arrays max_slices = max(len(self.contours), len(self.start_end_idx)) self.contours = self.contours + [None] * (max_slices - len(self.contours)) @@ -1171,7 +1177,7 @@ def save_thickness_values(self, output_path: str): Args: output_path (str): Path where to save the CSV file. """ - with open(output_path, 'w') as f: + with open(output_path, "w") as f: # Write header f.write("slice_idx,thickness\n") # Write data @@ -1196,10 +1202,10 @@ def load_thickness_values(self, input_path: str, original_thickness_vertices_pat ValueError: If the number of thickness values doesn't match the number of measurement points, or if the number of slices is inconsistent. 
""" - data = np.loadtxt(input_path, delimiter=',', skiprows=1) + data = np.loadtxt(input_path, delimiter=",", skiprows=1) slice_indices = data[:, 0].astype(int) values = data[:, 1] - + # Group values by slice_idx unique_slices = np.unique(slice_indices) @@ -1213,36 +1219,53 @@ def load_thickness_values(self, input_path: str, original_thickness_vertices_pat # check that the number of thickness values for each slice is equal to the number of points in the contour for slice_idx, thickness in enumerate(loaded_thickness_values): if thickness is not None: - assert len(thickness) == len(self.contours[slice_idx]), \ - "Number of thickness values does not match number of points in the contour, maybe you need to provide the measurement points file" + assert len(thickness) == len(self.contours[slice_idx]), ( + "Number of thickness values does not match number of points in the contour, maybe you need to " + "provide the measurement points file" + ) # fill original_thickness_vertices with all indices - self.original_thickness_vertices = [np.arange(len(self.contours[slice_idx])) for slice_idx in range(len(self.contours))] + self.original_thickness_vertices = [ + np.arange(len(self.contours[slice_idx])) for slice_idx in range(len(self.contours)) + ] else: - loaded_original_thickness_vertices = self._load_thickness_measurement_points(original_thickness_vertices_path) + loaded_original_thickness_vertices = self._load_thickness_measurement_points( + original_thickness_vertices_path + ) if len(loaded_original_thickness_vertices) != len(loaded_thickness_values): - raise ValueError("Number of slices in measurement points does not match number of slices in provided thickness values") + raise ValueError( + "Number of slices in measurement points does not match number of " + "slices in provided thickness values" + ) # check that original_thickness_vertices is equal to number of measurement points for each slice for slice_idx, vertex_indices in enumerate(loaded_original_thickness_vertices): - if len(vertex_indices) // 2 == len(loaded_thickness_values[slice_idx]) or len(vertex_indices) // 2 == np.sum(~np.isnan(loaded_thickness_values[slice_idx])): + if len(vertex_indices) // 2 == len(loaded_thickness_values[slice_idx]) or len( + vertex_indices + ) // 2 == np.sum(~np.isnan(loaded_thickness_values[slice_idx])): is_thickness_profile = True - elif len(vertex_indices) == len(loaded_thickness_values[slice_idx]) or len(vertex_indices) == np.sum(~np.isnan(loaded_thickness_values[slice_idx])): + elif len(vertex_indices) == len(loaded_thickness_values[slice_idx]) or len(vertex_indices) == np.sum( + ~np.isnan(loaded_thickness_values[slice_idx]) + ): is_thickness_profile = False else: raise ValueError("Number of measurement points does not match number of thickness values") - # create nan thickness value array for each slice - new_thickness_values = [np.full(len(self.contours[slice_idx]), np.nan) for slice_idx in range(len(self.contours))] + new_thickness_values = [ + np.full(len(self.contours[slice_idx]), np.nan) for slice_idx in range(len(self.contours)) + ] for slice_idx, vertex_indices in enumerate(loaded_original_thickness_vertices): if is_thickness_profile: - new_thickness_values[slice_idx][vertex_indices] = np.concatenate([loaded_thickness_values[slice_idx],loaded_thickness_values[slice_idx][::-1]]) + new_thickness_values[slice_idx][vertex_indices] = np.concatenate( + [loaded_thickness_values[slice_idx], loaded_thickness_values[slice_idx][::-1]] + ) else: - new_thickness_values[slice_idx][vertex_indices] = 
loaded_thickness_values[slice_idx][~np.isnan(loaded_thickness_values[slice_idx])] + new_thickness_values[slice_idx][vertex_indices] = loaded_thickness_values[slice_idx][ + ~np.isnan(loaded_thickness_values[slice_idx]) + ] self.thickness_values = new_thickness_values - def to_fs_coordinates(self): """Convert mesh coordinates to FreeSurfer coordinate system. @@ -1252,7 +1275,7 @@ def to_fs_coordinates(self): self.v = self.v[:, [2, 0, 1]] self.v[:, 1] -= 128 self.v[:, 2] += 128 - + def write_fssurf(self, filename): """Write the mesh to a FreeSurfer surface file. @@ -1263,7 +1286,7 @@ def write_fssurf(self, filename): The result of the parent class's write_fssurf method. """ return super().write_fssurf(filename) - + def write_overlay(self, filename): """Write the thickness values as a FreeSurfer overlay file. @@ -1274,7 +1297,7 @@ def write_overlay(self, filename): The result of writing the morph data using nibabel. """ return nib.freesurfer.write_morph_data(filename, self.mesh_vertex_colors) - + def save_thickness_measurement_points(self, filename): """Write the thickness measurement points to a CSV file. @@ -1284,7 +1307,7 @@ def save_thickness_measurement_points(self, filename): Args: filename (str): Path where to save the CSV file. """ - with open(filename, 'w') as f: + with open(filename, "w") as f: f.write("slice_idx,vertex_idx\n") for slice_idx, vertex_indices in enumerate(self.original_thickness_vertices): if vertex_indices is not None: @@ -1302,7 +1325,7 @@ def _load_thickness_measurement_points(filename): list: List of arrays containing vertex indices for each slice where thickness was measured. """ - data = np.loadtxt(filename, delimiter=',', skiprows=1) + data = np.loadtxt(filename, delimiter=",", skiprows=1) slice_indices = data[:, 0].astype(int) vertex_indices = data[:, 1].astype(int) @@ -1314,4 +1337,4 @@ def _load_thickness_measurement_points(filename): for slice_idx in unique_slices: mask = slice_indices == slice_idx original_thickness_vertices[slice_idx] = vertex_indices[mask] - return original_thickness_vertices \ No newline at end of file + return original_thickness_vertices diff --git a/CorpusCallosum/shape/cc_metrics.py b/CorpusCallosum/shape/cc_metrics.py index dbb7b7d4..96d7daca 100644 --- a/CorpusCallosum/shape/cc_metrics.py +++ b/CorpusCallosum/shape/cc_metrics.py @@ -1,65 +1,65 @@ import numpy as np + def calculate_cc_index(cc_contour): """ Calculate CC index based on three perpendicular measurements. 
- + Args: cc_contour: 2xN array of contour points in ACPC space - + Returns: float: Sum of thicknesses at three measurement points """ # Get anterior and posterior points anterior_idx = np.argmin(cc_contour[0]) # Leftmost point posterior_idx = np.argmax(cc_contour[0]) # Rightmost point - + # Get the longest line (anterior to posterior) - ap_line = cc_contour[:,posterior_idx] - cc_contour[:,anterior_idx] + ap_line = cc_contour[:, posterior_idx] - cc_contour[:, anterior_idx] ap_length = np.linalg.norm(ap_line) ap_unit = np.array([-ap_line[1], ap_line[0]]) / ap_length - + # Get midpoint of AP line - midpoint = cc_contour[:,anterior_idx] + (ap_line/2) - + midpoint = cc_contour[:, anterior_idx] + (ap_line / 2) + # Get perpendicular direction - - + # Get intersection points with contour for each measurement line def get_intersections(start_point, direction): # Get all points above and below the line - points = cc_contour.T - start_point[None,:] + points = cc_contour.T - start_point[None, :] dots = np.dot(points, direction) signs = np.sign(dots) sign_changes = np.where(np.diff(signs))[0] - + intersections = [] for idx in sign_changes: # Linear interpolation between points - t = -dots[idx] / (dots[idx+1] - dots[idx]) - intersection = cc_contour[:,idx] + t * (cc_contour[:,idx+1] - cc_contour[:,idx]) + t = -dots[idx] / (dots[idx + 1] - dots[idx]) + intersection = cc_contour[:, idx] + t * (cc_contour[:, idx + 1] - cc_contour[:, idx]) intersections.append(intersection) - + return np.array(intersections) - + # Get three measurements - most_anterior_pt = cc_contour[:,anterior_idx] + most_anterior_pt = cc_contour[:, anterior_idx] perpendicular_unit = np.array([-ap_unit[1], ap_unit[0]]) - - anterior_intersections = get_intersections(most_anterior_pt - 10*perpendicular_unit, ap_unit) + anterior_intersections = get_intersections(most_anterior_pt - 10 * perpendicular_unit, ap_unit) # sort by x - anterior_intersections = anterior_intersections[np.argsort(anterior_intersections[:,0])] + anterior_intersections = anterior_intersections[np.argsort(anterior_intersections[:, 0])] - middle_ints = get_intersections(midpoint, perpendicular_unit) + middle_ints = get_intersections(midpoint, perpendicular_unit) if len(middle_ints) != 2: - print(f"WARNING: The perpendicular line should intersect the contour twice, but it intersects {len(middle_ints)} times") + print( + f"WARNING: The perpendicular line should intersect the contour twice, " + f"but it intersects {len(middle_ints)} times" + ) # plt.close() - - # calculate index ap_distance = np.linalg.norm(anterior_intersections[0] - anterior_intersections[-1]) @@ -69,51 +69,46 @@ def get_intersections(start_point, direction): index = (anterior_distance + posterior_distance + top_distance) / ap_distance - - - # fig, ax = plt.subplots(figsize=(8, 6)) - + # # Plot the CC contour # ax.plot(cc_contour[0], cc_contour[1], 'k-', linewidth=1) # # add line from last to first - # ax.plot([cc_contour[0,-1], cc_contour[0,0]], [cc_contour[1,-1], cc_contour[1,0]], + # ax.plot([cc_contour[0,-1], cc_contour[0,0]], [cc_contour[1,-1], cc_contour[1,0]], # 'k-', linewidth=1) - + # # Plot AP line - # ax.plot([cc_contour[0,anterior_idx], cc_contour[0,posterior_idx]], - # [cc_contour[1,anterior_idx], cc_contour[1,posterior_idx]], + # ax.plot([cc_contour[0,anterior_idx], cc_contour[0,posterior_idx]], + # [cc_contour[1,anterior_idx], cc_contour[1,posterior_idx]], # 'r--', linewidth=1)#, label='Anterior-posterior line') - - + # # Plot the three measurement lines # for i, ints in 
enumerate(zip(anterior_intersections[:-1], anterior_intersections[1:])): # if i != 1: - # ax.plot([ints[0][0], ints[1][0]], [ints[0][1], ints[1][1]], + # ax.plot([ints[0][0], ints[1][0]], [ints[0][1], ints[1][1]], # 'b-', linewidth=1, label='Measurement line horizontal' if i==0 else None) - - # ax.plot([middle_ints[0,0], middle_ints[1,0]], [middle_ints[0,1], middle_ints[1,1]], - # 'g-', linewidth=1, label='Measurement lines vertical') + # ax.plot([middle_ints[0,0], middle_ints[1,0]], [middle_ints[0,1], middle_ints[1,1]], + # 'g-', linewidth=1, label='Measurement lines vertical') # print(middle_ints[0,], middle_ints[1,1]) # print(midpoint[1], midpoint[0]) - # ax.plot([middle_ints[0,0], midpoint[0]], [middle_ints[0,1], midpoint[1]], + # ax.plot([middle_ints[0,0], midpoint[0]], [middle_ints[0,1], midpoint[1]], # 'r--', linewidth=1)#, label='Superior-inferior line') # #plt.scatter(midpoint[0], midpoint[1], color='green', s=20) - + # ax.set_aspect('equal') # ax.legend() # # add gray background to CC contour # # Fill the inside of the contour with a gray shade # from matplotlib.path import Path # from matplotlib.patches import PathPatch - + # # Create a path from the contour points # contour_path = Path(np.array([cc_contour[0], cc_contour[1]]).T) - + # # Create a patch from the path and add it to the axes # patch = PathPatch(contour_path, facecolor='gray', alpha=0.2, edgecolor=None) # ax.add_patch(patch) @@ -125,54 +120,3 @@ def get_intersections(start_point, direction): # plt.show() return index - -if __name__ == "__main__": - import matplotlib.pyplot as plt - from cc_thickness import convert_to_ras - from shape.cc_endpoint_heuristic import get_endpoints - import pandas as pd - import nibabel as nib - from tqdm import tqdm - # Create visualization of CC index measurements - - - paths_csv = pd.read_csv('/groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/network/data/found_labels_with_meta_data_difficult_final.csv', index_col=0) - - - for subj_num, subj_id in enumerate(tqdm(paths_csv.index)): - #subj_id = '099f7f5a' - - label_path = paths_csv.loc[subj_id, 'label_merged'] - - try: - cc_label_nib = nib.load(label_path) - except Exception as e: - import pdb; pdb.set_trace() - print(subj_id, 'error', e) - continue - - PC_2d = paths_csv.loc[subj_id, 'PC_center_r':'PC_center_s'].to_numpy().astype(float)[1:] - AC_2d = paths_csv.loc[subj_id, 'AC_center_r':'AC_center_s'].to_numpy().astype(float)[1:] - - - cc_mask = cc_label_nib.get_fdata() == 192 - cc_mask = cc_mask[cc_mask.shape[0]//2] - - - contour, anterior_endpoint_idx, posterior_endpoint_idx = get_endpoints(cc_mask, AC_2d, PC_2d, cc_label_nib.header.get_zooms()[1], return_coordinates=False) - - - - contour = convert_to_ras(contour, cc_label_nib.affine) - - contour_2d=contour#[[2,0]].T[1:] - #contour = contour[[2,0,1]] - - index = calculate_cc_index(contour_2d) - - print(subj_id, index) - - - - - diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/cc_postprocessing.py index 70fdc154..7063b01b 100644 --- a/CorpusCallosum/shape/cc_postprocessing.py +++ b/CorpusCallosum/shape/cc_postprocessing.py @@ -1,10 +1,14 @@ from pathlib import Path -import nibabel as nib import numpy as np from shape.cc_thickness import convert_to_ras, cc_thickness -from shape.cc_subsegment_contour import subdivide_contour, transform_to_acpc_standard, subsegment_midline_orthogonal, hampel_subdivide_contour +from shape.cc_subsegment_contour import ( + subdivide_contour, + transform_to_acpc_standard, + subsegment_midline_orthogonal, + 
hampel_subdivide_contour, +) from shape.cc_endpoint_heuristic import get_endpoints from shape.cc_metrics import calculate_cc_index from shape.cc_subsegment_contour import get_primary_eigenvector @@ -21,7 +25,8 @@ LIA_ORIENTATION[2,1] = -1 -def create_visualization(subdivision_method, result, midslices_data, output_image_path, ac_coords, pc_coords, vox_size, title_suffix=""): +def create_visualization(subdivision_method, result, midslices_data, output_image_path, + ac_coords, pc_coords, vox_size, title_suffix=""): """Helper function to create visualization plots based on subdivision method. Args: @@ -44,8 +49,8 @@ def create_visualization(subdivision_method, result, midslices_data, output_imag output_image_path, ac_coords, pc_coords, vox_size, title) else: return run_in_background(plot_contours, False, midslices_data, - None, result['split_contours_hofer_frahm'], result['midline_equidistant'], result['levelpaths'], - output_image_path, ac_coords, pc_coords, vox_size, title) + None, result['split_contours_hofer_frahm'], result['midline_equidistant'], + result['levelpaths'], output_image_path, ac_coords, pc_coords, vox_size, title) def create_slice_affine(temp_seg_affine, slice_idx, fsaverage_middle): @@ -67,7 +72,8 @@ def create_slice_affine(temp_seg_affine, slice_idx, fsaverage_middle): return slice_affine -def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thickness_points, subdivisions, subdivision_method, contour_smoothing, verbose=False): +def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thickness_points, subdivisions, + subdivision_method, contour_smoothing): """Process a single slice for corpus callosum measurements. Performs detailed analysis of a corpus callosum slice, including: @@ -116,27 +122,46 @@ def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thi raise ValueError(f'No CC found in slice {slice_idx}') - contour, anterior_endpoint_idx, posterior_endpoint_idx = get_endpoints(cc_mask_slice, ac_coords, pc_coords, affine.diagonal()[1], return_coordinates=False, contour_smoothing=contour_smoothing) + contour, anterior_endpoint_idx, posterior_endpoint_idx = get_endpoints(cc_mask_slice, ac_coords, pc_coords, + affine.diagonal()[1], + return_coordinates=False, + contour_smoothing=contour_smoothing) contour_1mm = convert_to_ras(contour, affine) - midline_length, thickness, curvature, midline_equidistant, levelpaths, contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx = cc_thickness(contour_1mm.T, anterior_endpoint_idx, posterior_endpoint_idx, n_points=num_thickness_points) - thickness_profile = [np.sum(np.sqrt(np.diff(np.array(levelpath[:,:2]),axis=0)**2),axis=0) for levelpath in levelpaths] + (midline_length, thickness, curvature, midline_equidistant, levelpaths, + contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx) = cc_thickness(contour_1mm.T, + anterior_endpoint_idx, + posterior_endpoint_idx, + n_points=num_thickness_points) + + thickness_profile = [ + np.sum(np.sqrt(np.diff(np.array(levelpath[:,:2]), axis=0)**2), axis=0) + for levelpath in levelpaths + ] thickness_profile = np.linalg.norm(np.array(thickness_profile),axis=1) - contour_acpc, ac_pt_acpc, pc_pt_acpc, rotate_back_acpc = transform_to_acpc_standard(contour_1mm, contour_1mm[:,anterior_endpoint_idx], contour_1mm[:,posterior_endpoint_idx]) + contour_acpc, ac_pt_acpc, pc_pt_acpc, rotate_back_acpc = transform_to_acpc_standard(contour_1mm, + contour_1mm[:,anterior_endpoint_idx], + 
contour_1mm[:,posterior_endpoint_idx]) cc_index = calculate_cc_index(contour_acpc) # Apply different subdivision methods based on user choice if subdivision_method == "shape": - areas, split_contours = subsegment_midline_orthogonal(midline_equidistant, subdivisions, contour_1mm, plot=False) - split_contours = [transform_to_acpc_standard(split_contour, contour_1mm[:,anterior_endpoint_idx], contour_1mm[:,posterior_endpoint_idx])[0] for split_contour in split_contours] + areas, split_contours = subsegment_midline_orthogonal(midline_equidistant, subdivisions, + contour_1mm, plot=False) + split_contours = [transform_to_acpc_standard(split_contour, + contour_1mm[:,anterior_endpoint_idx], + contour_1mm[:,posterior_endpoint_idx])[0] + for split_contour in split_contours] + split_contours_hofer_frahm = None elif subdivision_method == "vertical": areas, split_contours = subdivide_contour(contour_acpc, subdivisions, plot=False) split_contours_hofer_frahm = split_contours.copy() elif subdivision_method == "angular": if not np.allclose(np.diff(subdivisions), np.diff(subdivisions)[0]): - print('Error: Angular subdivision method (Hampel) only supports equidistant subdivision, but got: ', subdivisions) + print('Error: Angular subdivision method (Hampel) only supports equidistant subdivision, ' + f'but got: {subdivisions}') return None areas, split_contours = hampel_subdivide_contour(contour_acpc, num_rays=len(subdivisions), plot=False) split_contours_hofer_frahm = split_contours.copy() @@ -217,10 +242,17 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac slice_idx = segmentation.shape[0] // 2 slice_affine = create_slice_affine(temp_seg_affine, slice_idx, FSAVERAGE_MIDDLE) - result, contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx = process_slice(segmentation, slice_idx, ac_coords, pc_coords, slice_affine, + result, contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx = process_slice(segmentation, + slice_idx, + ac_coords, + pc_coords, + slice_affine, num_thickness_points, subdivisions, subdivision_method, contour_smoothing) - cc_mesh.add_contour(0, contour_with_thickness[0], contour_with_thickness[1], start_end_idx=(anterior_endpoint_idx, posterior_endpoint_idx)) + cc_mesh.add_contour(0, + contour_with_thickness[0], + contour_with_thickness[1], + start_end_idx=(anterior_endpoint_idx, posterior_endpoint_idx)) if result is not None: slice_results.append(result) @@ -250,16 +282,18 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac slice_affine = create_slice_affine(temp_seg_affine, slice_idx, FSAVERAGE_MIDDLE) # Process this slice - result, contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx = process_slice(segmentation, slice_idx, ac_coords, pc_coords, slice_affine, - num_thickness_points, subdivisions, subdivision_method, contour_smoothing) + (result, contour_with_thickness, + anterior_endpoint_idx, posterior_endpoint_idx) = process_slice(segmentation, slice_idx, + ac_coords, pc_coords, + slice_affine, num_thickness_points, + subdivisions, subdivision_method, + contour_smoothing) # insert - cc_mesh.add_contour(slice_idx, contour_with_thickness[0], contour_with_thickness[1], start_end_idx=(anterior_endpoint_idx, posterior_endpoint_idx)) - #cc_mesh.plot_contour(slice_idx, output_dir / f'slice_{slice_idx}' / 'contour.png') - - #cc_mesh.plot_contour(slice_idx, output_dir / f'slice_{slice_idx}' / 'contour_filled.png') - - + cc_mesh.add_contour(slice_idx, + contour_with_thickness[0], + 
contour_with_thickness[1], + start_end_idx=(anterior_endpoint_idx, posterior_endpoint_idx)) if result is not None: slice_results.append(result) @@ -274,7 +308,8 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac # Create visualization for this slice IO_processes.append(create_visualization(subdivision_method, result, midslices[slice_idx:slice_idx+1], - output_subdir, ac_coords, pc_coords, vox_size, f' (Slice {slice_idx})')) + output_subdir, ac_coords, pc_coords, + vox_size, f' (Slice {slice_idx})')) if save_template is not None: # Convert to Path object and ensure directory exists diff --git a/CorpusCallosum/shape/cc_subsegment_contour.py b/CorpusCallosum/shape/cc_subsegment_contour.py index 4a754793..977fa353 100644 --- a/CorpusCallosum/shape/cc_subsegment_contour.py +++ b/CorpusCallosum/shape/cc_subsegment_contour.py @@ -1,12 +1,5 @@ -import nibabel as nib import numpy as np from scipy.spatial import ConvexHull -from tqdm import tqdm -import pandas as pd - -from shape.cc_thickness import cc_thickness, convert_to_ras -from shape.cc_endpoint_heuristic import get_endpoints - def minimum_bounding_rectangle(points): @@ -17,28 +10,23 @@ def minimum_bounding_rectangle(points): :param points: an nx2 matrix of coordinates :rval: an nx2 matrix of coordinates """ - pi2 = np.pi/2. + pi2 = np.pi / 2.0 points = points.T # get the convex hull for the points hull_points = points[ConvexHull(points).vertices] # calculate edge angles - edges = np.zeros((len(hull_points)-1, 2)) + edges = np.zeros((len(hull_points) - 1, 2)) edges = hull_points[1:] - hull_points[:-1] - angles = np.zeros((len(edges))) angles = np.arctan2(edges[:, 1], edges[:, 0]) angles = np.abs(np.mod(angles, pi2)) angles = np.unique(angles) # find rotation matrices - rotations = np.vstack([ - np.cos(angles), - np.cos(angles-pi2), - np.cos(angles+pi2), - np.cos(angles)]).T + rotations = np.vstack([np.cos(angles), np.cos(angles - pi2), np.cos(angles + pi2), np.cos(angles)]).T rotations = rotations.reshape((-1, 2, 2)) # apply rotations to the hull @@ -66,7 +54,7 @@ def minimum_bounding_rectangle(points): rval[1] = np.dot([x2, y2], r) rval[2] = np.dot([x2, y1], r) rval[3] = np.dot([x1, y1], r) - + return rval @@ -75,45 +63,41 @@ def get_area_from_subsegments(split_contours): areas = [np.abs(np.trapz(split_contour[1], split_contour[0])) for split_contour in split_contours] area_out = np.zeros(len(areas)) for i in range(len(areas)): - if i == len(areas)-1: + if i == len(areas) - 1: area_out[i] = areas[i] else: - area_out[i] = areas[i] - areas[i+1] + area_out[i] = areas[i] - areas[i + 1] return area_out def subsegment_midline_orthogonal(midline, area_weights, contour, plot=True, ax=None, extremes=None): - # get points after midline length of splits # get vertex closest to midline end midline_end_idx = np.argmin(np.linalg.norm(contour.T - midline[-1], axis=1)) # roll contour start to midline end contour = np.roll(contour, -midline_end_idx, axis=1) - + edge_idx, edge_frac = np.divmod(len(midline) * np.array(area_weights), 1) edge_idx = edge_idx.astype(int) - split_points = midline[edge_idx] + (midline[edge_idx+1] - midline[edge_idx]) * edge_frac[:,None] + split_points = midline[edge_idx] + (midline[edge_idx + 1] - midline[edge_idx]) * edge_frac[:, None] # get edge for each split point - edge_directions = midline[edge_idx] - midline[edge_idx+1] + edge_directions = midline[edge_idx] - midline[edge_idx + 1] # get vector perpendicular to each midline edge - edge_ortho_vectors = np.column_stack((-edge_directions[:,1], 
edge_directions[:,0])) - edge_ortho_vectors = edge_ortho_vectors / np.linalg.norm(edge_ortho_vectors, axis=1)[:,None] - + edge_ortho_vectors = np.column_stack((-edge_directions[:, 1], edge_directions[:, 0])) + edge_ortho_vectors = edge_ortho_vectors / np.linalg.norm(edge_ortho_vectors, axis=1)[:, None] + split_contours = [] split_contours.append(contour) - - - for pt_idx,split_point in enumerate(split_points): + + for pt_idx, split_point in enumerate(split_points): intersections = [] for i in range(contour.shape[1] - 1): - # get contour segment segment_start = contour[:, i] segment_end = contour[:, i + 1] segment_vector = segment_end - segment_start - # Check for intersection with the perpendicular line matrix = np.array([segment_vector, -edge_ortho_vectors[pt_idx]]).T @@ -131,12 +115,13 @@ def subsegment_midline_orthogonal(midline, area_weights, contour, plot=True, ax= # plt.plot(contour[0], contour[1], 'k-') # plt.plot(midline[:,0], midline[:,1], 'k--') # plt.plot(split_point[0], split_point[1], 'ro') - + # plt.plot([segment_start[0], segment_end[0]], [segment_start[1], segment_end[1]], 'bo', linewidth=2) - # plt.plot([split_point[0]-edge_ortho_vectors[pt_idx][0], split_point[0]+edge_ortho_vectors[pt_idx][0]], [split_point[1]-edge_ortho_vectors[pt_idx][1], split_point[1]+edge_ortho_vectors[pt_idx][1]], 'k-', linewidth=2) + # plt.plot([split_point[0]-edge_ortho_vectors[pt_idx][0], split_point[0]+edge_ortho_vectors[pt_idx][0]], + # [split_point[1]-edge_ortho_vectors[pt_idx][1], + # split_point[1]+edge_ortho_vectors[pt_idx][1]], 'k-', linewidth=2) # plt.show() - # get the two intersections closest to split_point intersections.sort(key=lambda x: np.linalg.norm(x[1] - split_point)) @@ -150,130 +135,133 @@ def subsegment_midline_orthogonal(midline, area_weights, contour, plot=True, ax= first_intersection, second_intersection = second_intersection, first_intersection first_index += 1 - #second_index += 1 + # second_index += 1 # connect first and second half - start_to_cutoff = np.hstack((contour[:, :first_index], first_intersection[:, None], second_intersection[:, None], contour[:, second_index + 1:])) + start_to_cutoff = np.hstack( + ( + contour[:, :first_index], + first_intersection[:, None], + second_intersection[:, None], + contour[:, second_index + 1 :], + ) + ) split_contours.append(start_to_cutoff) else: - raise ValueError('No intersections found, this should not happen') - + raise ValueError("No intersections found, this should not happen") + # plot contour to first index, then split point, then contour to second index - + # import matplotlib.pyplot as plt # plt.close() # fig, ax = plt.subplots(1,1) - # ax.plot(contour[:, :first_index][0], contour[:, :first_index][1], '-', linewidth=2, color='grey', label='Contour to first index') - # ax.plot(first_intersection[0], first_intersection[1], 'o', markersize=8, color='red', label='First intersection') - # ax.plot(second_intersection[0], second_intersection[1], 'o', markersize=8, color='red', label='Second intersection') - # ax.plot(contour[:, second_index + 1:][0], contour[:, second_index + 1:][1], '-', linewidth=2, color='red', label='Contour to second index') + # ax.plot(contour[:, :first_index][0], contour[:, :first_index][1], '-', linewidth=2, color='grey', + # label='Contour to first index') + # ax.plot(first_intersection[0], first_intersection[1], 'o', markersize=8, color='red', + # label='First intersection') + # ax.plot(second_intersection[0], second_intersection[1], 'o', markersize=8, color='red', + # label='Second intersection') + 
# ax.plot(contour[:, second_index + 1:][0], contour[:, second_index + 1:][1], '-', linewidth=2, color='red', + # label='Contour to second index') # ax.legend() # ax.set_title('Split Contours') # ax.set_aspect('equal') # ax.axis('off') # plt.show() - - - + if plot: extremes = [midline[0], midline[-1]] - + plot_transform = None if plot_transform is not None: split_contours = [plot_transform(split_contour) for split_contour in split_contours] contour = plot_transform(contour) - extremes = [plot_transform(extreme[:,None]) for extreme in extremes] - split_points = [plot_transform(split_point[:,None]) for split_point in split_points] - split_points_vlines_start = plot_transform(split_points_vlines_start) - split_points_vlines_end = plot_transform(split_points_vlines_end) - - import seaborn as sns + extremes = [plot_transform(extreme[:, None]) for extreme in extremes] + split_points = [plot_transform(split_point[:, None]) for split_point in split_points] + # split_points_vlines_start = plot_transform(split_points_vlines_start) + # split_points_vlines_end = plot_transform(split_points_vlines_end) + import matplotlib.pyplot as plt + if ax is None: SHOW = True - fig, ax = plt.subplots(1,1,figsize=(8, 6)) - ax.axis('equal') + fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + ax.axis("equal") else: SHOW = False - # pretty plot with areas filles in the polygon and overall area annotated - colors = sns.color_palette("ch:start=.2,rot=-.3", len(split_contours)) - for i, (color, split_contour) in enumerate(zip(colors, split_contours)): + # pretty plot with areas filled in the polygon and overall area annotated + colors = plt.cm.Spectral(np.linspace(0.2, 0.8, len(split_contours))) + for color, split_contour in zip(colors, split_contours, strict=True): ax.fill(split_contour[0], split_contour[1], alpha=0.5, color=color) - #ax.text(np.mean(split_contour[0]), np.mean(split_contour[1]), f'{area_out[i]:.2f}', color='black', fontsize=12) + # ax.text(np.mean(split_contour[0]), np.mean(split_contour[1]), f'{area_out[i]:.2f}', + # olor='black', fontsize=12) # plot contour - ax.plot(contour[0], contour[1], '-', linewidth=2, color='grey') + ax.plot(contour[0], contour[1], "-", linewidth=2, color="grey") # put text between split points # add enpoints to split_points split_points = split_points.tolist() split_points.insert(0, extremes[0]) split_points.append(extremes[1]) - #ax.scatter(np.array(split_points)[:,0], np.array(split_points)[:,1], color='black', s=20) - ax.plot(midline[:,0], midline[:,1], 'k--', linewidth=2) - + # ax.scatter(np.array(split_points)[:,0], np.array(split_points)[:,1], color='black', s=20) + ax.plot(midline[:, 0], midline[:, 1], "k--", linewidth=2) # plot edge orthogonal to each split point - for i in range(0,len(edge_ortho_vectors)): - pt = split_points[i+1] + for i in range(0, len(edge_ortho_vectors)): + pt = split_points[i + 1] length = 0.4 - ax.plot([pt[0]-edge_ortho_vectors[i][0]*length, pt[0]+edge_ortho_vectors[i][0]*length], [pt[1]-edge_ortho_vectors[i][1]*length, pt[1]+edge_ortho_vectors[i][1]*length], 'k-', linewidth=2) + ax.plot( + [pt[0] - edge_ortho_vectors[i][0] * length, pt[0] + edge_ortho_vectors[i][0] * length], + [pt[1] - edge_ortho_vectors[i][1] * length, pt[1] + edge_ortho_vectors[i][1] * length], + "k-", + linewidth=2, + ) # convert area_weights into fraction of total line length # e.g. area_weights=[1/6, 1/2, 2/3, 3/4] to ['1/6', '2/3', ...] 
# cumulative difference area_weights_diff = [] area_weights_diff.append(area_weights[0]) - for i in range(1,len(area_weights)): - area_weights_diff.append(area_weights[i] - area_weights[i-1]) + for i in range(1, len(area_weights)): + area_weights_diff.append(area_weights[i] - area_weights[i - 1]) area_weights_diff.append(1 - area_weights[-1]) - - #area_weights_txt = area_weights_txt / area_weights_txt[-1] - from fractions import Fraction - area_weights_txt = [Fraction(area_weights_diff[i]).limit_denominator(1000) for i in range(len(area_weights_diff))] - - for i in range(len(split_points)-1): + + for i in range(len(split_points) - 1): # get_index of split_points[i] in midline sp1_midline_idx = np.argmin(np.linalg.norm(midline - split_points[i], axis=1)) - sp2_midline_idx = np.argmin(np.linalg.norm(midline - split_points[i+1], axis=1)) + sp2_midline_idx = np.argmin(np.linalg.norm(midline - split_points[i + 1], axis=1)) # get midpoint on midline midpoint_idx = (sp1_midline_idx + sp2_midline_idx) // 2 midpoint = midline[midpoint_idx] # get vector perpendicular to line between split points - vector = np.array(split_points[i+1]) - np.array(split_points[i]) + vector = np.array(split_points[i + 1]) - np.array(split_points[i]) vector = vector / np.linalg.norm(vector) vector = np.array([-vector[1], vector[0]]) midpoint = midpoint - vector * 3 - #ax.text(midpoint[0]-5, midpoint[1]-5, f'{area_out[i]:.2f}', color='black', fontsize=12) - #ax.text(midpoint[0], midpoint[1], f'{area_weights_txt[i]}', color='black', fontsize=12, horizontalalignment='center', verticalalignment='center') + # ax.text(midpoint[0]-5, midpoint[1]-5, f'{area_out[i]:.2f}', color='black', fontsize=12) + # ax.text(midpoint[0], midpoint[1], f'{area_weights_txt[i]}', color='black', fontsize=12, + # horizontalalignment='center', verticalalignment='center') - - # start point & end point - ax.plot(extremes[0][0], extremes[0][1], marker='o', markersize=8, color='black') - ax.plot(extremes[1][0], extremes[1][1], marker='o', markersize=8, color='black') - + ax.plot(extremes[0][0], extremes[0][1], marker="o", markersize=8, color="black") + ax.plot(extremes[1][0], extremes[1][1], marker="o", markersize=8, color="black") # plot contour point 0 - #ax.scatter(contour[0,0], contour[1,0], color='red', s=120) - ax.set_title('Split Contours') - + # ax.scatter(contour[0,0], contour[1,0], color='red', s=120) + ax.set_title("Split Contours") + if SHOW: - ax.axis('off') + ax.axis("off") ax.invert_xaxis() - ax.axis('equal') + ax.axis("equal") plt.show() - return get_area_from_subsegments(split_contours), split_contours - - - def hampel_subdivide_contour(contour, num_rays, plot=False, ax=None): - # Find the extreme points in the x-direction min_x_index = np.argmin(contour[0]) contour = np.roll(contour, -min_x_index, axis=1) @@ -287,49 +275,47 @@ def hampel_subdivide_contour(contour, num_rays, plot=False, ax=None): long_edges = np.linalg.norm(long_edges, axis=1) long_edges_idx = np.argpartition(long_edges, -2)[-2:] - # select lower long edge min_val = np.inf min_idx = None for i in long_edges_idx: - - if rectangle_duplicate_last[i][1] < min_val: + if rectangle_duplicate_last[i][1] < min_val: min_val = rectangle_duplicate_last[i][1] min_idx = i - - if rectangle_duplicate_last[i+1][1] < min_val: - min_val = rectangle_duplicate_last[i+1][1] + + if rectangle_duplicate_last[i + 1][1] < min_val: + min_val = rectangle_duplicate_last[i + 1][1] min_idx = i - lowest_points = rectangle_duplicate_last[[min_idx, min_idx+1]] + lowest_points = 
rectangle_duplicate_last[[min_idx, min_idx + 1]] # sort lowest points by x coordinate - if lowest_points[0,0] < lowest_points[1,0]: + if lowest_points[0, 0] < lowest_points[1, 0]: lowest_points = lowest_points[::-1] - - + # get midpoint of lower edge of rectangle midpoint_lower_edge = np.mean(lowest_points, axis=0) - + # get angle of lower edge of rectangle to x-axis - angle_lower_edge = np.arctan2(lowest_points[1, 1] - lowest_points[0, 1], lowest_points[1, 0] - lowest_points[0, 0]) #% (np.pi) - - #steps = np.pi / num_rays - - #print(np.degrees(angle_lower_edge)) + angle_lower_edge = np.arctan2( + lowest_points[1, 1] - lowest_points[0, 1], lowest_points[1, 0] - lowest_points[0, 0] + ) # % (np.pi) + + # steps = np.pi / num_rays + + # print(np.degrees(angle_lower_edge)) # get angles for equally spaced rays - angles = np.linspace(-angle_lower_edge, -angle_lower_edge + np.pi, num_rays+2, endpoint=True) #+ np.pi *3 + angles = np.linspace(-angle_lower_edge, -angle_lower_edge + np.pi, num_rays + 2, endpoint=True) # + np.pi *3 angles = angles[1:-1] - + # create ray vectors ray_vectors = np.vstack((np.cos(angles), np.sin(angles))) # make ray vectors unit length ray_vectors = ray_vectors / np.linalg.norm(ray_vectors, axis=0) - + # invert x of ray vectors ray_vectors[0] = -ray_vectors[0] - - + # Subdivision logic split_contours = [] for ray_vector in ray_vectors.T: @@ -352,64 +338,71 @@ def hampel_subdivide_contour(contour, num_rays, plot=False, ax=None): # Sort intersections by their position along the contour intersections.sort() - # Create new contours by splitting at intersections if intersections: first_index, first_intersection = intersections[0] second_index, second_intersection = intersections[-1] - start_to_cutoff = np.hstack((contour[:, :first_index], first_intersection[:, None], second_intersection[:, None], contour[:, second_index + 1:])) - + start_to_cutoff = np.hstack( + ( + contour[:, :first_index], + first_intersection[:, None], + second_intersection[:, None], + contour[:, second_index + 1 :], + ) + ) + # connect first and second half split_contours.append(start_to_cutoff) else: - raise ValueError('No intersections found, this should not happen') - + raise ValueError("No intersections found, this should not happen") + split_contours.append(contour) split_contours = split_contours[::-1] - - - #split_contours = split_contours[::-1] + + # split_contours = split_contours[::-1] # Plotting logic if plot: - import seaborn as sns import matplotlib.pyplot as plt + if ax is None: fig, ax = plt.subplots(1, 1, figsize=(8, 6)) - ax.axis('equal') + ax.axis("equal") SHOW = True else: SHOW = False min_bounding_rectangle_plot = np.vstack((min_bounding_rectangle, min_bounding_rectangle[0])) - #ax.plot(contour[0], contour[1], 'b-', label='Original Contour') - ax.plot(min_bounding_rectangle_plot[:,0], min_bounding_rectangle_plot[:,1], 'k--') - ax.plot(midpoint_lower_edge[0], midpoint_lower_edge[1], 'ko', markersize=8) - for i, ray_vector in enumerate(ray_vectors.T): + # ax.plot(contour[0], contour[1], 'b-', label='Original Contour') + ax.plot(min_bounding_rectangle_plot[:, 0], min_bounding_rectangle_plot[:, 1], "k--") + ax.plot(midpoint_lower_edge[0], midpoint_lower_edge[1], "ko", markersize=8) + for ray_vector in ray_vectors.T: ray_length = 15 ray_vector *= -ray_length - ax.plot([midpoint_lower_edge[0], midpoint_lower_edge[0] + ray_vector[0]], - [midpoint_lower_edge[1], midpoint_lower_edge[1] + ray_vector[1]], - 'k--') + ax.plot( + [midpoint_lower_edge[0], midpoint_lower_edge[0] + ray_vector[0]], + 
[midpoint_lower_edge[1], midpoint_lower_edge[1] + ray_vector[1]], + "k--", + ) # pretty plot with areas filles in the polygon and overall area annotated - colors = sns.color_palette("ch:start=.2,rot=-.3", len(split_contours)) - for i, (color, split_contour) in enumerate(zip(colors, split_contours)): + colors = plt.cm.Spectral(np.linspace(0.2, 0.8, len(split_contours))) + for color, split_contour in zip(colors, split_contours, strict=True): ax.fill(split_contour[0], split_contour[1], alpha=0.5, color=color) - ax.plot(contour[0], contour[1], '-', linewidth=2, color='grey') - - ax.set_title('Split Contours') - ax.axis('off') + ax.plot(contour[0], contour[1], "-", linewidth=2, color="grey") + + ax.set_title("Split Contours") + ax.axis("off") if SHOW: - ax.axis('equal') + ax.axis("equal") plt.show() - return get_area_from_subsegments(split_contours), split_contours -def subdivide_contour(contour, area_weights, plot=False, ax=None, plot_transform=None, oriented=False, hline_anchor=None): - +def subdivide_contour( + contour, area_weights, plot=False, ax=None, plot_transform=None, oriented=False, hline_anchor=None +): # Find the extreme points in the x-direction min_x_index = np.argmax(contour[0]) contour = np.roll(contour, -min_x_index, axis=1) @@ -417,11 +410,7 @@ def subdivide_contour(contour, area_weights, plot=False, ax=None, plot_transform min_x_index = 0 max_x_index = np.argmin(contour[0]) - - - if oriented: - contour_x_sorted = np.sort(contour[0]) min_x = contour_x_sorted[0] max_x = contour_x_sorted[-1] @@ -430,16 +419,16 @@ def subdivide_contour(contour, area_weights, plot=False, ax=None, plot_transform if hline_anchor is not None: extremes = (np.array([max_x, hline_anchor[1]]), np.array([min_x, hline_anchor[1]])) - - # only keep x values of extremes and set y 5 mm below most inferior point of contour # if hline_anchor is None: # most_inferior_point = np.min(contour[1]) - # extremes = (np.array([extremes[0][0], most_inferior_point - 5]), np.array([extremes[1][0], most_inferior_point - 5])) + # extremes = (np.array([extremes[0][0], most_inferior_point - 5]), + # np.array([extremes[1][0], most_inferior_point - 5])) # else: # # get y diffrence between extremes and hline_anchor # y_diff = extremes[1][1] - hline_anchor[1] - # extremes = (np.array([extremes[0][0], extremes[0][1] - y_diff]), np.array([extremes[1][0], extremes[1][1] - y_diff])) + # extremes = (np.array([extremes[0][0], extremes[0][1] - y_diff]), + # np.array([extremes[1][0], extremes[1][1] - y_diff])) else: extremes = (contour[:, min_x_index].copy(), contour[:, max_x_index].copy()) # Calculate the line between the extreme points @@ -453,10 +442,11 @@ def subdivide_contour(contour, area_weights, plot=False, ax=None, plot_transform # Calculate the perpendicular vector perp_vector = np.array([-line_unit_vector[1], line_unit_vector[0]]) perp_vector = perp_vector / np.linalg.norm(perp_vector) - + if hline_anchor is None: most_inferior_point = np.min(contour[1]) - # move extreme 1 down 5 mm below inferior point and extreme 2 the same distance (so the angle stays the same) + # move extreme 1 down 5 mm below inferior point and extreme 2 the + # same distance (so the angle stays the same) down_distance = (extremes[1][1] - most_inferior_point) * 1.3 start_point = extremes[0] + down_distance * perp_vector end_point = extremes[1] + down_distance * perp_vector @@ -481,12 +471,8 @@ def subdivide_contour(contour, area_weights, plot=False, ax=None, plot_transform # move start and end point the same distance start_point = extremes[0] + distance * 
perp_vector end_point = extremes[1] + distance * perp_vector - - - extremes = (start_point, end_point) - - + extremes = (start_point, end_point) # Calculate the line between the extreme points start_point, end_point = extremes @@ -501,13 +487,11 @@ def subdivide_contour(contour, area_weights, plot=False, ax=None, plot_transform # Calculate split points based on area weights split_points = [] - for i,weight in enumerate(area_weights): - #current_weight = np.sum(area_weights[:i]) + for weight in area_weights: + # current_weight = np.sum(area_weights[:i]) split_distance = weight * line_length split_point = start_point + split_distance * line_unit_vector split_points.append(split_point) - - # Split the contour at the calculated split points split_contours = [] @@ -531,14 +515,13 @@ def subdivide_contour(contour, area_weights, plot=False, ax=None, plot_transform intersections.append((i, intersection_point)) # Sort intersections by their position along the contour - #intersections.sort() - + # intersections.sort() + # get the two intersections that have the highest y coordinate intersections.sort(key=lambda x: x[1][1], reverse=True) # Create new contours by splitting at intersections if intersections: - first_index, first_intersection = intersections[1] second_index, second_intersection = intersections[0] @@ -547,13 +530,13 @@ def subdivide_contour(contour, area_weights, plot=False, ax=None, plot_transform first_intersection, second_intersection = second_intersection, first_intersection first_index += 1 - #second_index += 1 - - - - #start_to_cutoff = np.hstack((contour[:, :first_index], first_intersection[:, None], second_intersection[:, None], contour[:, second_index + 1:])) - start_to_cutoff = np.hstack((first_intersection[:, None], contour[:, first_index:second_index], second_intersection[:, None])) + # second_index += 1 + # start_to_cutoff = np.hstack((contour[:, :first_index], first_intersection[:, None], + # second_intersection[:, None], contour[:, second_index + 1:])) + start_to_cutoff = np.hstack( + (first_intersection[:, None], contour[:, first_index:second_index], second_intersection[:, None]) + ) # import matplotlib.pyplot as plt # plt.close() @@ -564,17 +547,16 @@ def subdivide_contour(contour, area_weights, plot=False, ax=None, plot_transform # ax.plot(second_intersection[0], second_intersection[1], 'go', label='Second Intersection') # # ax.plot(contour[:, :first_index][0], contour[:, :first_index][1]+0.5, 'r-', label='First Segment') # # ax.plot(contour[:, second_index+1:][0], contour[:, second_index+1:][1]+1, 'g-', label='Second Segment') - # ax.plot(contour[:, first_index:second_index][0], contour[:, first_index:second_index][1]+0.5, 'r-', label='Segment') + # ax.plot(contour[:, first_index:second_index][0], + # contour[:, first_index:second_index][1]+0.5, 'r-', label='Segment') # ax.plot(start_to_cutoff[0], start_to_cutoff[1], 'g-', label='Start to Cutoff') # ax.legend() # plt.show() - + # connect first and second half split_contours.append(start_to_cutoff) else: - raise ValueError('No intersections found, this should not happen') - - + raise ValueError("No intersections found, this should not happen") # if plot: # import matplotlib.pyplot as plt @@ -591,7 +573,7 @@ def subdivide_contour(contour, area_weights, plot=False, ax=None, plot_transform # plt.legend() # plt.axis('equal') # plt.show() - + # # same plot but segment are moved apart by 5 mm # plt.figure(figsize=(8, 6)) # for i, split_contour in enumerate(split_contours): @@ -603,58 +585,66 @@ def subdivide_contour(contour, 
area_weights, plot=False, ax=None, plot_transform # plt.legend() # plt.axis('equal') # plt.show() - - - - + if plot: # make vline at every split point split_points_vlines_start = (np.array(split_points) - perp_vector * 1).T split_points_vlines_end = (np.array(split_points) + perp_vector * 1).T - - if oriented: - # make another vline at start point and end point, this time not perpendicular to line but perpendicular to x-axis - start_point_vline = np.array([start_point, np.array([start_point[0], start_point[1]+8])]) - end_point_vline = np.array([end_point, np.array([end_point[0], end_point[1]+8])]) + if oriented: + # make another vline at start point and end point, this time not + # perpendicular to line but perpendicular to x-axis + start_point_vline = np.array([start_point, np.array([start_point[0], start_point[1] + 8])]) + end_point_vline = np.array([end_point, np.array([end_point[0], end_point[1] + 8])]) else: start_point_vline = np.array([start_point, start_point - perp_vector * 8]) end_point_vline = np.array([end_point, end_point - perp_vector * 8]) - + if plot_transform is not None: split_contours = [plot_transform(split_contour) for split_contour in split_contours] contour = plot_transform(contour) - extremes = [plot_transform(extreme[:,None]) for extreme in extremes] - split_points = [plot_transform(split_point[:,None]) for split_point in split_points] + extremes = [plot_transform(extreme[:, None]) for extreme in extremes] + split_points = [plot_transform(split_point[:, None]) for split_point in split_points] split_points_vlines_start = plot_transform(split_points_vlines_start) split_points_vlines_end = plot_transform(split_points_vlines_end) start_point_vline = plot_transform(start_point_vline.T).T end_point_vline = plot_transform(end_point_vline.T).T - - import seaborn as sns + import matplotlib.pyplot as plt + if ax is None: SHOW = True - fig, ax = plt.subplots(1,1,figsize=(8, 6)) - ax.axis('equal') + fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + ax.axis("equal") else: SHOW = False # pretty plot with areas filles in the polygon and overall area annotated - colors = sns.color_palette("ch:start=.2,rot=-.3", len(split_contours)) - for i, (color, split_contour) in enumerate(zip(colors, split_contours)): + colors = plt.cm.Spectral(np.linspace(0.2, 0.8, len(split_contours))) + for color, split_contour in zip(colors, split_contours, strict=True): ax.fill(split_contour[0], split_contour[1], alpha=0.5, color=color) - #ax.text(np.mean(split_contour[0]), np.mean(split_contour[1]), f'{area_out[i]:.2f}', color='black', fontsize=12) + # ax.text(np.mean(split_contour[0]), np.mean(split_contour[1]), + # f'{area_out[i]:.2f}', color='black', fontsize=12) # plot contour - ax.plot(contour[0], contour[1], '-', linewidth=2, color='grey') + ax.plot(contour[0], contour[1], "-", linewidth=2, color="grey") # dashed line between start point & end point - ax.plot(np.vstack((extremes[0][0], extremes[1][0])), np.vstack((extremes[0][1], extremes[1][1])), '--', linewidth=2, color='grey') + ax.plot( + np.vstack((extremes[0][0], extremes[1][0])), + np.vstack((extremes[0][1], extremes[1][1])), + "--", + linewidth=2, + color="grey", + ) # markers at every split point for i in range(split_points_vlines_start.shape[1]): - ax.plot(np.vstack((split_points_vlines_start[:,i][0], split_points_vlines_end[:,i][0])), - np.vstack((split_points_vlines_start[:,i][1], split_points_vlines_end[:,i][1])), 'k-', linewidth=2) - - ax.plot(start_point_vline[:,0], start_point_vline[:,1], '--', linewidth=2, color='grey') - 
ax.plot(end_point_vline[:,0], end_point_vline[:,1], '--', linewidth=2, color='grey') + ax.plot( + np.vstack((split_points_vlines_start[:, i][0], split_points_vlines_end[:, i][0])), + np.vstack((split_points_vlines_start[:, i][1], split_points_vlines_end[:, i][1])), + "k-", + linewidth=2, + ) + + ax.plot(start_point_vline[:, 0], start_point_vline[:, 1], "--", linewidth=2, color="grey") + ax.plot(end_point_vline[:, 0], end_point_vline[:, 1], "--", linewidth=2, color="grey") # put text between split points # add enpoints to split_points split_points.insert(0, extremes[0]) @@ -664,72 +654,71 @@ def subdivide_contour(contour, area_weights, plot=False, ax=None, plot_transform # cumulative difference area_weights_diff = [] area_weights_diff.append(area_weights[0]) - for i in range(1,len(area_weights)): - area_weights_diff.append(area_weights[i] - area_weights[i-1]) + for i in range(1, len(area_weights)): + area_weights_diff.append(area_weights[i] - area_weights[i - 1]) area_weights_diff.append(1 - area_weights[-1]) - - #area_weights_txt = area_weights_txt / area_weights_txt[-1] + + # area_weights_txt = area_weights_txt / area_weights_txt[-1] from fractions import Fraction - area_weights_txt = [Fraction(area_weights_diff[i]).limit_denominator(1000) for i in range(len(area_weights_diff))] - - for i in range(len(split_points)-1): - midpoint = np.mean([split_points[i], split_points[i+1]], axis=0) - #ax.text(midpoint[0]-5, midpoint[1]-5, f'{area_out[i]:.2f}', color='black', fontsize=12) - ax.text(midpoint[0], midpoint[1]-5, f'{area_weights_txt[i]}', color='black', fontsize=11, horizontalalignment='center') - - - - # start point & end point - ax.plot(extremes[0][0], extremes[0][1], marker='o', markersize=8, color='black') - ax.plot(extremes[1][0], extremes[1][1], marker='o', markersize=8, color='black') + area_weights_txt = [ + Fraction(area_weights_diff[i]).limit_denominator(1000) for i in range(len(area_weights_diff)) + ] + + for i in range(len(split_points) - 1): + midpoint = np.mean([split_points[i], split_points[i + 1]], axis=0) + # ax.text(midpoint[0]-5, midpoint[1]-5, f'{area_out[i]:.2f}', color='black', fontsize=12) + ax.text( + midpoint[0], + midpoint[1] - 5, + f"{area_weights_txt[i]}", + color="black", + fontsize=11, + horizontalalignment="center", + ) + + # start point & end point + ax.plot(extremes[0][0], extremes[0][1], marker="o", markersize=8, color="black") + ax.plot(extremes[1][0], extremes[1][1], marker="o", markersize=8, color="black") # plot contour 0 point - #ax.scatter(contour[0,0], contour[1,0], color='red', s=100) + # ax.scatter(contour[0,0], contour[1,0], color='red', s=100) - - - - ax.set_title('Split Contours') + ax.set_title("Split Contours") # ax.set_xlabel('X') # ax.set_ylabel('Y') - + # axis off - ax.axis('off') + ax.axis("off") if SHOW: - ax.axis('equal') + ax.axis("equal") plt.show() - return get_area_from_subsegments(split_contours), split_contours def transform_to_acpc_standard(contour_ras, ac_pt_ras, pc_pt_ras): # translate AC to the origin and PC to (0, ac_pc_dist) - translation_matrix = np.array([[1, 0, -ac_pt_ras[0]], - [0, 1, -ac_pt_ras[1]], - [0, 0, 1]]) - - ac_pc_vec = pc_pt_ras - ac_pt_ras + translation_matrix = np.array([[1, 0, -ac_pt_ras[0]], [0, 1, -ac_pt_ras[1]], [0, 0, 1]]) + + ac_pc_vec = pc_pt_ras - ac_pt_ras ac_pc_dist = np.linalg.norm(ac_pc_vec) - + posterior_vector = np.array([-ac_pc_dist, 0]) - + # get angle between ac_pc_vec and posterior_vector dot_product = np.dot(ac_pc_vec, posterior_vector) norms_product = np.linalg.norm(ac_pc_vec) * 
np.linalg.norm(posterior_vector) theta = np.arccos(dot_product / norms_product) - + # Determine the sign of the angle using cross product cross_product = np.cross(ac_pc_vec, posterior_vector) if cross_product < 0: theta = -theta - + # create rotation matrix for theta - rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], - [np.sin(theta), np.cos(theta), 0], - [0, 0, 1]]) - + rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1]]) + # apply translation and rotation if contour_ras.shape[0] == 2: contour_ras_homogeneous = np.vstack([contour_ras, np.ones(contour_ras.shape[1])]) @@ -738,17 +727,19 @@ def transform_to_acpc_standard(contour_ras, ac_pt_ras, pc_pt_ras): contour_acpc = (rotation_matrix @ translation_matrix) @ contour_ras_homogeneous contour_acpc = contour_acpc[:2, :] - - rotate_back = lambda x: (np.linalg.inv(rotation_matrix @ translation_matrix) @ np.vstack([x, np.ones(x.shape[1])]))[:2, :] + + def rotate_back(x): + return (np.linalg.inv(rotation_matrix @ translation_matrix) @ np.vstack([x, np.ones(x.shape[1])]))[:2, :] + return contour_acpc, np.array([0, 0]), np.array([-ac_pc_dist, 0]), rotate_back + def preprocess_cc(cc_label_nib, paths_csv, subj_id): cc_mask = cc_label_nib.get_fdata() == 192 - cc_mask = cc_mask[cc_mask.shape[0]//2] + cc_mask = cc_mask[cc_mask.shape[0] // 2] - - posterior_commisure_center = paths_csv.loc[subj_id, 'PC_center_r':'PC_center_s'].to_numpy().astype(float) - anterior_commisure_center = paths_csv.loc[subj_id, 'AC_center_r':'AC_center_s'].to_numpy().astype(float) + posterior_commisure_center = paths_csv.loc[subj_id, "PC_center_r":"PC_center_s"].to_numpy().astype(float) + anterior_commisure_center = paths_csv.loc[subj_id, "AC_center_r":"AC_center_s"].to_numpy().astype(float) # adjust LR from label coordinates to orig_up coordinates posterior_commisure_center[0] = 128 @@ -766,20 +757,20 @@ def get_primary_eigenvector(contour_ras): # Center the data by subtracting mean contour_mean = np.mean(contour_ras, axis=1, keepdims=True) contour_centered = contour_ras - contour_mean - + # Calculate covariance matrix cov_matrix = np.cov(contour_centered) - + # Get eigenvalues and eigenvectors using PCA eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix) - + # Sort in descending order idx = eigenvalues.argsort()[::-1] eigenvalues = eigenvalues[idx] - eigenvectors = eigenvectors[:,idx] - + eigenvectors = eigenvectors[:, idx] + # make first eigentor unit length - primary_eigenvector = eigenvectors[:,0] / np.linalg.norm(eigenvectors[:,0]) + primary_eigenvector = eigenvectors[:, 0] / np.linalg.norm(eigenvectors[:, 0]) pt0 = np.mean(contour_ras, axis=1) pt0 -= np.array([0, 5]) pt1 = pt0 + primary_eigenvector * 100 @@ -790,199 +781,6 @@ def get_primary_eigenvector(contour_ras): # # plot line between pt0 and pt1 # ax[0].plot([pt0[0], pt1[0]], [pt0[1], pt1[1]], 'r-', linewidth=2) # plt.show() - - return pt0, pt1 - -if __name__ == "__main__": - - OUTPUT_TO_RAS = False - PLOT = False - - paths_csv = pd.read_csv('/groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/network/data/found_labels_with_meta_data_difficult_final.csv', index_col=0) - - - FOUND = False - for subj_id in tqdm(paths_csv.index): - - - # if subj_id != 'b1213f65' and not FOUND: - # print(subj_id, 'skipping') - # continue - - # FOUND = True - - - #subj_id = '04ac873f' - #print(subj_id) - - #subj_id = '099f7f5a' - #label_path = '099f7f5a-norm-cc_up-cropped-labels_v02-l02_merged.mgz' - label_path = paths_csv.loc[subj_id, 'label_merged'] - - 
try: - cc_label_nib = nib.load(label_path) - except Exception as e: - import pdb; pdb.set_trace() - print(subj_id, 'error', e) - continue - - PC_2d = paths_csv.loc[subj_id, 'PC_center_r':'PC_center_s'].to_numpy().astype(float)[1:] - AC_2d = paths_csv.loc[subj_id, 'AC_center_r':'AC_center_s'].to_numpy().astype(float)[1:] - - - cc_mask = cc_label_nib.get_fdata() == 192 - cc_mask = cc_mask[cc_mask.shape[0]//2] - - contour, anterior_endpoint_idx, posterior_endpoint_idx = get_endpoints(cc_mask, AC_2d, PC_2d, cc_label_nib.header.get_zooms()[1], return_coordinates=False) - contour_as = convert_to_ras(contour, cc_label_nib.affine) - ac_pt_as = convert_to_ras(AC_2d, cc_label_nib.affine) - pc_pt_as = convert_to_ras(PC_2d, cc_label_nib.affine) - - np.save(f'./contours/{subj_id}_contour_as.npy', contour_as) - np.save(f'./contours/{subj_id}_ac_pt_as.npy', ac_pt_as) - np.save(f'./contours/{subj_id}_pc_pt_as.npy', pc_pt_as) - continue - - #### contour to ACPC standard #### - contour_acpc, ac_pt_acpc, pc_pt_acpc, rotate_back_acpc = transform_to_acpc_standard(contour_as, ac_pt_as, pc_pt_as) - - - import matplotlib.pyplot as plt - - print(subj_id) - - # fig, ax = plt.subplots(1,1,figsize=(5, 4)) - # ax.plot(contour_acpc[0], contour_acpc[1], 'b-', label='Contour ACPC') - # ax.plot(ac_pt_acpc[0], ac_pt_acpc[1], 'gx', markersize=8, label='AC (ACPC)') - # ax.plot(pc_pt_acpc[0], pc_pt_acpc[1], 'rx', markersize=8, label='PC (ACPC)') - - # ax.plot(contour_ras[0], contour_ras[1], 'y-', label='Contour RAS') - # ax.plot(ac_pt_ras[0], ac_pt_ras[1], 'gx', markersize=8, label='AC (RAS)') - # ax.plot(pc_pt_ras[0], pc_pt_ras[1], 'rx', markersize=8, label='PC (RAS)') - - # # transform back using rotate_back_acpc - # contour_ras_back = rotate_back_acpc(contour_acpc) - # ax.plot(contour_ras_back[0], contour_ras_back[1], 'g-', label='Contour RAS back') - - # ax.set_title(subj_id) - # ax.axis('equal') - # ax.legend() - # plt.show() - # plt.close() - - - - - # fig, ax = plt.subplots(1,1,figsize=(5, 4)) - # #ax.imshow(cc_mask, cmap='gray') - - # image_path = paths_csv.loc[subj_id, 'image_orig_up'] - # image_nib = nib.load(image_path) - # image = image_nib.get_fdata() - # ax.imshow(image[127][::-1], cmap='gray')#, vmin=100, vmax=256) - # #ax.imshow(cc_mask[::-1], cmap='heat', alpha=0.5) - # #contour_acpc[:,1] = contour_acpc[:,1][::-1] - # subdivide_contour(contour_acpc, area_weights=[1/6, 1/2, 2/3, 3/4], plot=PLOT, ax=ax, plot_transform=rotate_back_acpc) - # ax.plot(contour[1], image_nib.shape[2] - contour[0], 'y-', linewidth=3) - # # invert y axis - # ax.invert_yaxis() - # plt.show() - # plt.close() - - - fig, ax = plt.subplots(2,3,figsize=(12, 8), sharex=True, sharey=True) - - - # Aboitiz scheme - subdivided_contour = subdivide_contour(contour_as, area_weights=[1/3, 2/3, 4/5], plot=PLOT, ax=ax[0,0], oriented=False, hline_anchor=ac_pt_as) - ax[0,0].set_title('Aboitiz') - - # Witelson scheme - subdivided_contour = subdivide_contour(contour_as, area_weights=[1/3, 1/2, 2/3, 4/5], plot=PLOT, ax=ax[0,1], oriented=False, hline_anchor=ac_pt_as) - ax[0,1].set_title('Witelson') - - # Jaenecke - subdivided_contour = subdivide_contour(contour_acpc, area_weights=[1/3, 1/2, 2/3, 4/5], plot=PLOT, ax=ax[0,2], oriented=True, plot_transform=rotate_back_acpc, hline_anchor=ac_pt_acpc) - ax[0,2].set_title('Jäncke') - - # Hofer-Frahm - subdivided_contour = subdivide_contour(contour_as, area_weights=[1/6, 1/2, 2/3, 3/4], plot=PLOT, ax=ax[1,0], oriented=False, hline_anchor=ac_pt_as) - ax[1,0].set_title('Hofer-Frahm') - - - - # Hofer-Frahm + Jaenecke - 
# subdivided_contour = subdivide_contour(contour_acpc, area_weights=[1/6, 1/2, 2/3, 3/4], plot=PLOT, ax=ax[1,1], oriented=True, plot_transform=rotate_back_acpc, hline_anchor=ac_pt_acpc) - # ax[1,1].set_title('Hofer-Frahm + Jaenecke') - - subdivided_contour = hampel_subdivide_contour(contour_as, num_rays=4, plot=PLOT, ax=ax[1,1]) - ax[1,1].set_title('Hampel') - - - - pt0, pt1 = get_primary_eigenvector(contour_as) - contour_eigen, pt0_eigen, pt1_eigen, rotate_back_eigen = transform_to_acpc_standard(contour_as, pt0, pt1) - ac_pt_eigen, _, _, _ = transform_to_acpc_standard(ac_pt_as[:, None], pt0, pt1) - ac_pt_eigen = ac_pt_eigen[:, 0] - # fig, ax = plt.subplots(1,1,figsize=(5, 4)) - # ax.plot(contour_eigen[0], contour_eigen[1], 'b-', label='Contour Eigen') - # ax.plot(pt0_eigen[0], pt0_eigen[1], 'gx', markersize=8, label='AC (Eigen)') - # ax.plot(pt1_eigen[0], pt1_eigen[1], 'rx', markersize=8, label='PC (Eigen)') - # ax.plot(contour_ras[0], contour_ras[1], 'y-', label='Contour RAS') - # ax.plot(ac_pt_ras[0], ac_pt_ras[1], 'gx', markersize=8, label='AC (RAS)') - # ax.plot(pc_pt_ras[0], pc_pt_ras[1], 'rx', markersize=8, label='PC (RAS)') - # ax.axis('equal') - # ax.legend() - # plt.show() - # plt.close() - subdivided_contour = subdivide_contour(contour_eigen, area_weights=[1/5, 2/5, 3/5, 4/5], plot=PLOT, ax=ax[1,2], oriented=True, plot_transform=rotate_back_eigen, hline_anchor=ac_pt_eigen) - ax[1,2].set_title('mri_cc') - - - - - - try: - midline_length, thickness, curvature, midline_equidistant, levelpaths = cc_thickness(contour_as.T, anterior_endpoint_idx, posterior_endpoint_idx) - except Exception as e: - contour_as += np.random.randn(contour_as.shape[0], contour_as.shape[1])*0.0001 - midline_length, thickness, curvature, midline_equidistant, levelpaths = cc_thickness(contour_as.T, anterior_endpoint_idx, posterior_endpoint_idx) - print('Successfully computed thickness after adding noise') - - #contour_as = contour_as.T - - - - - - - - plt.tight_layout() - # make axis equal - for a in ax.flatten(): - a.set_aspect('equal', adjustable='box') - a.axis('off') - - - # first two rows - # for a in ax[0:2, :].flatten(): - # a.scatter(ac_pt_as[0], ac_pt_as[1], color='red', s=100, marker='x') - # a.scatter(pc_pt_as[0], pc_pt_as[1], color='blue', s=100, marker='x') - - ax[0,0].invert_xaxis() - - plt.savefig(f'/groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/cc_pipeline/subdivision_plots/cc_subdivisions_{subj_id}.png', dpi=300, bbox_inches='tight') - plt.show() - plt.close() + return pt0, pt1 - fig, ax = plt.subplots(1,1,figsize=(5, 4)) - areas, split_contours = subsegment_midline_orthogonal(midline_equidistant, [1/6, 1/2, 2/3, 3/4], contour_as, plot=True, ax=ax) - ax.invert_xaxis() - ax.set_title('Midline subdivision - Hofer-Frahm ratios') - ax.axis('equal') - ax.axis('off') - plt.savefig(f'/groups/ag-reuter/projects/corpus_callosum_fornix/pollakc/cc_pipeline/subdivision_plots/cc_subdivisions_{subj_id}_midline.png', dpi=300, bbox_inches='tight') - plt.show() - plt.close() - \ No newline at end of file diff --git a/CorpusCallosum/shape/cc_thickness.py b/CorpusCallosum/shape/cc_thickness.py index 24df8374..c64f6228 100644 --- a/CorpusCallosum/shape/cc_thickness.py +++ b/CorpusCallosum/shape/cc_thickness.py @@ -11,7 +11,7 @@ class HiddenPrints: def __enter__(self): self._original_stdout = sys.stdout - sys.stdout = open(os.devnull, 'w') + sys.stdout = open(os.devnull, "w") def __exit__(self, exc_type, exc_val, exc_tb): sys.stdout.close() @@ -21,22 +21,23 @@ def __exit__(self, exc_type, exc_val, 
exc_tb):
 def compute_curvature(path):
    # compute curvature by computing edge angles
    edges = np.diff(path, axis=0)
-    angles = np.arctan2(edges[:,1], edges[:,0])
+    angles = np.arctan2(edges[:, 1], edges[:, 0])
    # compute angle differences between consecutive edges
    angle_diffs = np.diff(angles)
    # wrap angles to [-pi, pi]
-    angle_diffs = np.mod(angle_diffs + np.pi, 2*np.pi) - np.pi
+    angle_diffs = np.mod(angle_diffs + np.pi, 2 * np.pi) - np.pi
    return angle_diffs

 def convert_to_ras(contour, vox2ras_matrix, get_parameters=False):
-    # converting to AS (no left-right dimension), out of plane movement is ignores, so we only do scaling, axes swapping and flipping - no rotation
+    # converting to AS (no left-right dimension), out of plane movement is ignored,
+    # so we only do scaling, axes swapping and flipping - no rotation
    # translation is ignored
    if contour.shape[0] == 2:
        # get only axis swaps
-        axis_swaps = np.round(vox2ras_matrix[:3,:3], 0)
-        permutation = np.argwhere(axis_swaps != 0)[:,1]
-        assert(len(permutation) == 3)
+        axis_swaps = np.round(vox2ras_matrix[:3, :3], 0)
+        permutation = np.argwhere(axis_swaps != 0)[:, 1]
+        assert len(permutation) == 3

        idx_superior = np.argwhere(permutation == 2)
        idx_anterior = np.argwhere(permutation == 1)
@@ -44,12 +45,11 @@ def convert_to_ras(contour, vox2ras_matrix, get_parameters=False):
        swap_axes = idx_anterior > idx_superior
        if swap_axes:
            # swap anterior and superior
-            contour = contour[[1,0]]
-
+            contour = contour[[1, 0]]
+
        # determine if axis were reversed
-        superior_reversed = (axis_swaps[2,:] == -1).any()
-        anterior_reversed = (axis_swaps[1,:] == -1).any()
-
+        superior_reversed = (axis_swaps[2, :] == -1).any()
+        anterior_reversed = (axis_swaps[1, :] == -1).any()

        # flip axes if necessary
        if superior_reversed:
@@ -58,7 +58,7 @@ def convert_to_ras(contour, vox2ras_matrix, get_parameters=False):
            contour[0] = -contour[0]

        # get scaling by getting length of three column vectors
-        scaling = np.linalg.norm(vox2ras_matrix[:3,:3], axis=0)
+        scaling = np.linalg.norm(vox2ras_matrix[:3, :3], axis=0)

        # apply transformation
        contour = (contour.T / scaling[1:]).T
@@ -71,11 +71,11 @@ def convert_to_ras(contour, vox2ras_matrix, get_parameters=False):
    #     # Add a third dimension (z) with 0 and a fourth dimension (homogeneous coordinate) with 1
    elif contour.shape[0] == 3:
        contour_homogeneous = np.vstack([contour, np.ones(contour.shape[1])])
-
+
        # Apply the transformation
        contour = (vox2ras_matrix @ contour_homogeneous)[:3, :]
    return contour
-
+

 def set_contour_zero_idx(contour, idx, anterior_endpoint_idx, posterior_endpoint_idx):
    contour = np.roll(contour, -idx, axis=0)
@@ -86,65 +86,78 @@ def set_contour_zero_idx(contour, idx, anterior_endpoint_idx, posterior_endpoint

 def find_closest_edge(point, contour):
    """Find the index of the edge closest to the given point.
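+
+    The point p is projected onto each edge (a, b) with t = clip(((p - a) . (b - a)) / |b - a|^2, 0, 1);
+    the edge that minimizes the distance from p to the projected point a + t * (b - a) is returned.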
- + Args: point: 2D point coordinates contour: Array of contour points (N x 2) - + Returns: Index of the closest edge """ edges_start = contour[:-1, :2] # N-1 x 2 - edges_end = contour[1:, :2] # N-1 x 2 + edges_end = contour[1:, :2] # N-1 x 2 edges_vec = edges_end - edges_start # N-1 x 2 - + # Calculate projection coefficient for all edges at once # (p-a)·(b-a) / |b-a|² edge_lengths_sq = np.sum(edges_vec * edges_vec, axis=1) # Avoid division by zero for degenerate edges valid_edges = edge_lengths_sq > 1e-10 t = np.zeros(len(edges_start)) - t[valid_edges] = np.sum((point - edges_start[valid_edges]) * edges_vec[valid_edges], axis=1) / edge_lengths_sq[valid_edges] + t[valid_edges] = ( + np.sum((point - edges_start[valid_edges]) * edges_vec[valid_edges], axis=1) + / edge_lengths_sq[valid_edges] + ) t = np.clip(t, 0, 1) # Clamp to edge endpoints - + # Get closest points on all edges - closest_points = edges_start + t[:,None] * edges_vec - + closest_points = edges_start + t[:, None] * edges_vec + # Calculate distances to all edges distances = np.linalg.norm(point - closest_points, axis=1) - + # Return index of closest edge return np.argmin(distances) -def insert_point_to_contour(contour_with_thickness, point, thickness_value, get_index=False): +def insert_point_to_contour( + contour_with_thickness, point, thickness_value, get_index=False +): """Insert a point and its thickness value into the contour. - + Args: contour_with_thickness: List containing [contour_points, thickness_values] point: 2D point to insert thickness_value: Thickness value corresponding to the point - + Returns: Updated contour_with_thickness """ # Find closest edge for the point edge_idx = find_closest_edge(point, contour_with_thickness[0]) - + # Insert point between edge endpoints - contour_with_thickness[0] = np.insert(contour_with_thickness[0], edge_idx+1, point, axis=0) - contour_with_thickness[1] = np.insert(contour_with_thickness[1], edge_idx+1, thickness_value) - + contour_with_thickness[0] = np.insert( + contour_with_thickness[0], edge_idx + 1, point, axis=0 + ) + contour_with_thickness[1] = np.insert( + contour_with_thickness[1], edge_idx + 1, thickness_value + ) + if get_index: - return contour_with_thickness, edge_idx+1 + return contour_with_thickness, edge_idx + 1 else: return contour_with_thickness def make_mesh_from_contour(contour_2d, max_volume=0.5, min_angle=25, verbose=False): - facets = np.vstack((np.arange(len(contour_2d)) , ((np.arange(len(contour_2d))+1) % len(contour_2d)))).T - + facets = np.vstack( + ( + np.arange(len(contour_2d)), + ((np.arange(len(contour_2d)) + 1) % len(contour_2d)), + ) + ).T # plot vertices and facets # import matplotlib.pyplot as plt @@ -159,22 +172,26 @@ def make_mesh_from_contour(contour_2d, max_volume=0.5, min_angle=25, verbose=Fal info.set_points(contour_2d) info.set_facets(facets) # NOTE: crashes if contour has duplicate points !! 
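    # A minimal dedup guard (editorial sketch in plain numpy, not part of the patch) that would
    # avoid the duplicate-point crash noted above; facets must be rebuilt afterwards:
    #   _, keep = np.unique(np.round(contour_2d, 9), axis=0, return_index=True)
    #   contour_2d = contour_2d[np.sort(keep)]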
- mesh = triangle.build(info, max_volume=max_volume, min_angle=min_angle, verbose=verbose) + mesh = triangle.build( + info, max_volume=max_volume, min_angle=min_angle, verbose=verbose + ) - mesh_points = np.array(mesh.points) mesh_trias = np.array(mesh.elements) return mesh_points, mesh_trias -def cc_thickness(contour_2d, anterior_endpoint_idx, posterior_endpoint_idx, n_points=100): +def cc_thickness( + contour_2d, anterior_endpoint_idx, posterior_endpoint_idx, n_points=100 +): # standardize contour indices, to get consistent levelpath directions - contour_2d, anterior_endpoint_idx, posterior_endpoint_idx = set_contour_zero_idx(contour_2d, anterior_endpoint_idx, anterior_endpoint_idx, posterior_endpoint_idx) - - mesh_points, mesh_trias = make_mesh_from_contour(contour_2d) + contour_2d, anterior_endpoint_idx, posterior_endpoint_idx = set_contour_zero_idx( + contour_2d, anterior_endpoint_idx, anterior_endpoint_idx, posterior_endpoint_idx + ) + mesh_points, mesh_trias = make_mesh_from_contour(contour_2d) # plot mesh points with index next to point # import matplotlib.pyplot as plt @@ -184,9 +201,8 @@ def cc_thickness(contour_2d, anterior_endpoint_idx, posterior_endpoint_idx, n_po # ax.text(mesh_points[i,0], mesh_points[i,1], str(i), fontsize=7) # plt.show() - # make points 3D by appending z=0 - mesh_points3d = np.append(mesh_points,np.zeros((mesh_points.shape[0],1)),axis=1) + mesh_points3d = np.append(mesh_points, np.zeros((mesh_points.shape[0], 1)), axis=1) # compute poisson with HiddenPrints(): @@ -195,32 +211,32 @@ def cc_thickness(contour_2d, anterior_endpoint_idx, posterior_endpoint_idx, n_po bdr = np.array(tria.boundary_loops()[0]) # find index of endpoints in bdr list - iidx1=np.where(bdr==anterior_endpoint_idx)[0][0] - iidx2=np.where(bdr==posterior_endpoint_idx)[0][0] + iidx1 = np.where(bdr == anterior_endpoint_idx)[0][0] + iidx2 = np.where(bdr == posterior_endpoint_idx)[0][0] # create boundary condition (0 at endpoints, -1 on one side, 1 on the other): if iidx1 > iidx2: - tmp= iidx2 + tmp = iidx2 iidx2 = iidx1 iidx1 = tmp dcond = np.ones(bdr.shape) - dcond[iidx1] =0 - dcond[iidx2] =0 - dcond[iidx1+1:iidx2] = -1 - + dcond[iidx1] = 0 + dcond[iidx2] = 0 + dcond[iidx1 + 1 : iidx2] = -1 # Extract path with HiddenPrints(): fem = Solver(tria) - vfunc = fem.poisson(0,(bdr,dcond)) + vfunc = fem.poisson(0, (bdr, dcond)) level = 0 - midline_equidistant, midline_length = tria.level_path(vfunc, level, n_points=n_points+2) - midline_equidistant = midline_equidistant[:,:2] - + midline_equidistant, midline_length = tria.level_path( + vfunc, level, n_points=n_points + 2 + ) + midline_equidistant = midline_equidistant[:, :2] # try: with HiddenPrints(): - gf = compute_rotated_f(tria,vfunc) + gf = compute_rotated_f(tria, vfunc) # except Exception as e: # Lot contour and path # import matplotlib.pyplot as plt @@ -233,39 +249,47 @@ def cc_thickness(contour_2d, anterior_endpoint_idx, posterior_endpoint_idx, n_po # mtpltlb_tria = tri.Triangulation(tria.v[:,0], tria.v[:,1], triangles=tria.t) # ax.triplot(mtpltlb_tria, 'k-', alpha=0.2, linewidth=0.5) # # Plot final endpoint estimates - # ax.plot(contour_2d[:,0][anterior_endpoint_idx], contour_2d[:,1][anterior_endpoint_idx], 'r*', + # ax.plot(contour_2d[:,0][anterior_endpoint_idx], contour_2d[:,1][anterior_endpoint_idx], 'r*', # markersize=15, label='Final estimate') - # ax.plot(contour_2d[:,0][posterior_endpoint_idx], contour_2d[:,1][posterior_endpoint_idx], 'r*', + # ax.plot(contour_2d[:,0][posterior_endpoint_idx], contour_2d[:,1][posterior_endpoint_idx], 
'r*', # markersize=15, label='Final estimate') # ax.legend() # #ax.set_title(f'Subject: {subj_id}') # plt.show() - + # interpolate midline to get levels to evaluate - gf_interp = scipy.interpolate.griddata(tria.v[:,0:2], gf, midline_equidistant[:,0:2], method='cubic') + gf_interp = scipy.interpolate.griddata( + tria.v[:, 0:2], gf, midline_equidistant[:, 0:2], method="cubic" + ) # get levels to evaluate - #level_length = tria.level_length(gf, gf_interp) + # level_length = tria.level_length(gf, gf_interp) levelpaths = [] levelpath_lengths = [] levelpath_tria_idx = [] contour_with_thickness = [contour_2d.copy(), np.full(contour_2d.shape[0], np.nan)] - for i in range(1,n_points+1): - level = gf_interp[i] + for i in range(1, n_points + 1): + level = gf_interp[i] # levelpath starts at index zero - lvlpath, lvlpath_length, tria_idx = tria.level_path(gf, level, get_tria_idx=True) - + lvlpath, lvlpath_length, tria_idx = tria.level_path( + gf, level, get_tria_idx=True + ) + levelpaths.append(lvlpath) levelpath_lengths.append(lvlpath_length) levelpath_tria_idx.append(tria_idx) - levelpath_start = lvlpath[0,:2] - levelpath_end = lvlpath[-1,:2] + levelpath_start = lvlpath[0, :2] + levelpath_end = lvlpath[-1, :2] - contour_with_thickness, inserted_idx_start = insert_point_to_contour(contour_with_thickness, levelpath_start, lvlpath_length, get_index=True) - contour_with_thickness, inserted_idx_end = insert_point_to_contour(contour_with_thickness, levelpath_end, lvlpath_length, get_index=True) + contour_with_thickness, inserted_idx_start = insert_point_to_contour( + contour_with_thickness, levelpath_start, lvlpath_length, get_index=True + ) + contour_with_thickness, inserted_idx_end = insert_point_to_contour( + contour_with_thickness, levelpath_end, lvlpath_length, get_index=True + ) # keep track of start and end indices if inserted_idx_start <= anterior_endpoint_idx: @@ -278,30 +302,28 @@ def cc_thickness(contour_2d, anterior_endpoint_idx, posterior_endpoint_idx, n_po if inserted_idx_end >= posterior_endpoint_idx: posterior_endpoint_idx += 1 - - - - - # import matplotlib.pyplot as plt # fig, ax = plt.subplots(figsize=(10, 8)) # cont = contour_with_thickness[0] # ax.plot(cont[:,0], cont[:,1], 'k-', label='Contour', marker='o', markersize=3) - # ax.scatter(cont[:,0][anterior_endpoint_idx], cont[:,1][anterior_endpoint_idx], c='r', label='Anterior Endpoint', marker='o') - # ax.scatter(cont[:,0][posterior_endpoint_idx], cont[:,1][posterior_endpoint_idx], c='b', label='Posterior Endpoint', marker='o') + # ax.scatter(cont[:,0][anterior_endpoint_idx], cont[:,1][anterior_endpoint_idx], c='r', + # label='Anterior Endpoint', marker='o') + # ax.scatter(cont[:,0][posterior_endpoint_idx], cont[:,1][posterior_endpoint_idx], c='b', + # label='Posterior Endpoint', marker='o') # ax.legend() # plt.show() - # thickness_measurement_points_top = [] - # thickness_measurement_points_bottom = [] + # thickness_measurement_points_top = [] + # thickness_measurement_points_bottom = [] # for i in range(len(levelpaths)): # thickness_measurement_points_top.append(levelpaths[i][0,:2]) # thickness_measurement_points_bottom.append(levelpaths[i][-1,:2]) # thickness_measurement_points_top = np.array(thickness_measurement_points_top) # thickness_measurement_points_bottom = np.array(thickness_measurement_points_bottom) - # thickness_measurement_points = np.concatenate([thickness_measurement_points_top, thickness_measurement_points_bottom], axis=0).T + # thickness_measurement_points = np.concatenate([thickness_measurement_points_top, + # 
thickness_measurement_points_bottom], axis=0).T # # Create a figure with subplots # fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5)) @@ -317,8 +339,10 @@ def cc_thickness(contour_2d, anterior_endpoint_idx, posterior_endpoint_idx, n_po # # Plot 2: Thickness measurement points # print(thickness_measurement_points.shape) - # ax2.plot(thickness_measurement_points[0, :100], -thickness_measurement_points[1, :100], 'ro', markersize=3, label='Thickness Points (start)') - # ax2.plot(thickness_measurement_points[0, 100:], -thickness_measurement_points[1, 100:], 'go', markersize=3, label='Thickness Points (end)') + # ax2.plot(thickness_measurement_points[0, :100], -thickness_measurement_points[1, :100], 'ro', + # markersize=3, label='Thickness Points (start)') + # ax2.plot(thickness_measurement_points[0, 100:], -thickness_measurement_points[1, 100:], 'go', + # markersize=3, label='Thickness Points (end)') # ax2.set_title('Thickness Measurement Points') # ax2.set_xlabel('X') # ax2.set_ylabel('Y') @@ -326,7 +350,6 @@ def cc_thickness(contour_2d, anterior_endpoint_idx, posterior_endpoint_idx, n_po # ax2.invert_yaxis() # ax2.legend() # plt.show() - # get curvature of path3d_resampled curvature = compute_curvature(midline_equidistant) @@ -335,7 +358,6 @@ def cc_thickness(contour_2d, anterior_endpoint_idx, posterior_endpoint_idx, n_po # print(f'Length of midline: ', f'{midline_length:.2f}') # print(f'Thickness: {np.mean(levelpath_lengths):.2f}') - # import matplotlib.pyplot as plt # import matplotlib.tri as tri # fig, ax = plt.subplots(figsize=(5, 4)) @@ -343,11 +365,13 @@ def cc_thickness(contour_2d, anterior_endpoint_idx, posterior_endpoint_idx, n_po # triang = plt.tricontourf(mtpltlb_tria, gf, cmap='autumn', alpha=0.2) # ax.plot(midline_equidistant[:,0], midline_equidistant[:,1], 'r-', label=f'Levelsets')#, marker='o', markersize=2) # #ax.plot(contour_2d[:,0], contour_2d[:,1], 'k-', label='Contour', alpha=0.6) - + # for i in range(len(levelpaths)): # if levelpaths[i] is not None: - # ax.plot(levelpaths[i][:,0], levelpaths[i][:,1], 'r-', marker='o', markersize=0) # , label=f'Level {levelpath_lengths[i]:.2f}' - # ax.plot(midline_equidistant[:,0], midline_equidistant[:,1], '-', label='Midline', alpha=1, color='darkgoldenrod')#, marker='o', markersize=2) + # ax.plot(levelpaths[i][:,0], levelpaths[i][:,1], 'r-', marker='o', markersize=0) # , + # label=f'Level {levelpath_lengths[i]:.2f}' + # ax.plot(midline_equidistant[:,0], midline_equidistant[:,1], '-', label='Midline', alpha=1, + # color='darkgoldenrod')#, marker='o', markersize=2) # #plt.colorbar(colorscale, label='Level values') # # plot mesh @@ -364,5 +388,13 @@ def cc_thickness(contour_2d, anterior_endpoint_idx, posterior_endpoint_idx, n_po # plt.savefig(f'levelsets.png', dpi=300, bbox_inches='tight') # plt.show() - return midline_length, np.mean(levelpath_lengths), out_curvature, midline_equidistant, levelpaths, contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx - + return ( + midline_length, + np.mean(levelpath_lengths), + out_curvature, + midline_equidistant, + levelpaths, + contour_with_thickness, + anterior_endpoint_idx, + posterior_endpoint_idx, + ) diff --git a/CorpusCallosum/shape/resample_poly.py b/CorpusCallosum/shape/resample_poly.py deleted file mode 100644 index e9236dee..00000000 --- a/CorpusCallosum/shape/resample_poly.py +++ /dev/null @@ -1,65 +0,0 @@ -import numpy as np - -def resample_polygon(xy: np.ndarray, n_points: int = 100) -> np.ndarray: - # Cumulative Euclidean distance between successive polygon points. 
- # This will be the "x" for interpolation - d = np.cumsum(np.r_[0, np.sqrt((np.diff(xy, axis=0) ** 2).sum(axis=1))]) - - # get linearly spaced points along the cumulative Euclidean distance - d_sampled = np.linspace(0, d.max(), n_points) - - # interpolate x and y coordinates - xy_interp = np.c_[ - np.interp(d_sampled, d, xy[:, 0]), - np.interp(d_sampled, d, xy[:, 1]), - ] - - return xy_interp - -def iterative_resample_polygon(xy: np.ndarray, n_points: int = 100, n_iter: int = 3) -> np.ndarray: - # resample multiple times to numerically stabilize the result to be truly equidistant - xy_resampled = resample_polygon(xy, n_points) - for _ in range(n_iter-1): - xy_resampled = resample_polygon(xy_resampled, n_points) - return xy_resampled - - -if __name__ == "__main__": - import time - import matplotlib.pyplot as plt - - coords = [ - {'x': 354.0, 'y': 424.0}, {'x': 318.0, 'y': 455.0}, {'x': 299.0, 'y': 458.0}, {'x': 284.0, 'y': 464.0}, {'x': 250.0, 'y': 490.0}, - {'x': 229.0, 'y': 492.0}, {'x': 204.0, 'y': 484.0}, {'x': 187.0, 'y': 469.0}, {'x': 176.0, 'y': 449.0}, {'x': 164.0, 'y': 435.0}, - {'x': 119.0, 'y': 274.0}, {'x': 121.0, 'y': 264.0}, {'x': 118.0, 'y': 249.0}, {'x': 118.0, 'y': 224.0}, {'x': 121.0, 'y': 209.0}, - {'x': 130.0, 'y': 194.0}, {'x': 138.0, 'y': 159.0}, {'x': 147.0, 'y': 139.0}, {'x': 155.0, 'y': 112.0}, {'x': 170.0, 'y': 89.0}, - {'x': 190.0, 'y': 67.0}, {'x': 220.0, 'y': 54.0}, {'x': 280.0, 'y': 47.0}, {'x': 310.0, 'y': 55.0}, {'x': 330.0, 'y': 56.0}, - {'x': 345.0, 'y': 60.0}, {'x': 355.0, 'y': 67.0}, {'x': 367.0, 'y': 80.0}, {'x': 375.0, 'y': 84.0}, {'x': 382.0, 'y': 95.0}, - ] - - # construct numpy array from list of dicts - xy = np.array([(c['x'], c['y']) for c in coords]) - - n_points = 30 - # resample polygon - print(f"Resampling polygon with {len(xy)} points to {n_points} points") - start_time = time.time() - xy_resampled = iterative_resample_polygon(xy, n_points, n_iter=20) - end_time = time.time() - print(f"Time taken: {end_time - start_time:.2f} seconds") - - # plot result - fig, ax = plt.subplots(figsize=(7,14)) - ax.scatter(xy[:, 1], xy[:, 0], marker='o', s=150, label='original', color='black') - ax.scatter(xy_resampled[:, 1], xy_resampled[:, 0], label='resampled', color='red') - ax.set_aspect(1) - ax.invert_yaxis() - plt.legend() - plt.show() - - # Calculate distances between consecutive vertices - distances = np.sqrt(np.sum((xy_resampled[1:] - xy_resampled[:-1])**2, axis=1)) - print('Distance between consecutive vertices:', distances) - - - diff --git a/CorpusCallosum/transforms/localization_transforms.py b/CorpusCallosum/transforms/localization_transforms.py index 30891eb8..a1bf6134 100644 --- a/CorpusCallosum/transforms/localization_transforms.py +++ b/CorpusCallosum/transforms/localization_transforms.py @@ -6,7 +6,8 @@ class CropAroundACPCFixedSize(RandomizableTransform, MapTransform): Crop around AC and PC with fixed size """ - def __init__(self, keys, fixed_size: tuple[int, int], allow_missing_keys: bool = False, random_translate: float = 0) -> None: + def __init__(self, keys, fixed_size: tuple[int, int], allow_missing_keys: bool = False, + random_translate: float = 0) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self) self.random_translate = random_translate diff --git a/CorpusCallosum/transforms/segmentation_transforms.py b/CorpusCallosum/transforms/segmentation_transforms.py index a296e942..1158d488 100644 --- a/CorpusCallosum/transforms/segmentation_transforms.py +++ 
b/CorpusCallosum/transforms/segmentation_transforms.py @@ -7,7 +7,8 @@ class CropAroundACPC(RandomizableTransform, MapTransform): Crop around AC and PC """ - def __init__(self, keys, allow_missing_keys: bool = False, padding_mm: float = 10, random_translate: float = 0) -> None: + def __init__(self, keys, allow_missing_keys: bool = False, padding_mm: float = 10, + random_translate: float = 0) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob=1, do_transform=True) self.padding_mm = padding_mm @@ -36,8 +37,10 @@ def __call__(self, data): # 'PC_center': array([ 2., 139., 143.], dtype=float32), 'AC_center': array([ 2., 128., 168.] - ac_pc_bottomleft = (np.min([ac_center[1], pc_center[1]]).astype(int), np.min([ac_center[2], pc_center[2]]).astype(int)) - ac_pc_topright = (np.max([ac_center[1], pc_center[1]]).astype(int), np.max([ac_center[2], pc_center[2]]).astype(int)) + ac_pc_bottomleft = (np.min([ac_center[1], pc_center[1]]).astype(int), + np.min([ac_center[2], pc_center[2]]).astype(int)) + ac_pc_topright = (np.max([ac_center[1], pc_center[1]]).astype(int), + np.max([ac_center[2], pc_center[2]]).astype(int)) voxel_padding = round(self.padding_mm / d['res']) @@ -85,21 +88,3 @@ def __call__(self, data): return d - - -class UncropAroundACPC(MapTransform): - """ - Uncrop around AC and PC - reverses CropAroundACPC transform by padding back to original size - """ - - def __init__(self, keys, allow_missing_keys: bool = False, padding_mm: float = 10) -> None: - super().__init__(keys, allow_missing_keys) - self.padding_mm = padding_mm - - def __call__(self, data): - pad_left, pad_right, pad_top, pad_bottom = d['to_pad'] - - # Pad back to original size - d[key] = np.pad(d[key], ((0,0), (0,0), (pad_left.item(), pad_right.item()), (pad_top.item(), pad_bottom.item())), mode='constant', constant_values=0) - - return d \ No newline at end of file diff --git a/CorpusCallosum/visualization/visualization.py b/CorpusCallosum/visualization/visualization.py index ae4f5fde..dcfab1da 100644 --- a/CorpusCallosum/visualization/visualization.py +++ b/CorpusCallosum/visualization/visualization.py @@ -1,49 +1,50 @@ from pathlib import Path import numpy as np import matplotlib.pyplot as plt -from typing import Tuple, List, Union -import nibabel as nib -from scipy.ndimage import affine_transform - -#from mapping_helpers import apply_transform_and_map_volume - def plot_standardized_space(ax_row, vol, ac_coords, pc_coords): """Plot standardized space visualization across three views. 
- + Args: ax_row: Row of axes to plot on (should be length 3) vol: Volume data to visualize ac_coords: AC coordinates in standardized space pc_coords: PC coordinates in standardized space """ - ax_row[0].set_title('Standardized') - + ax_row[0].set_title("Standardized") + # Axial view - ax_row[0].scatter(ac_coords[2], ac_coords[1], color='red', marker='x') - ax_row[0].scatter(pc_coords[2], pc_coords[1], color='blue', marker='x') - ax_row[0].imshow(vol[vol.shape[0]//2], cmap='gray') + ax_row[0].scatter(ac_coords[2], ac_coords[1], color="red", marker="x") + ax_row[0].scatter(pc_coords[2], pc_coords[1], color="blue", marker="x") + ax_row[0].imshow(vol[vol.shape[0] // 2], cmap="gray") # Sagittal view - ax_row[1].scatter(ac_coords[2], ac_coords[0], color='red', marker='x') - ax_row[1].scatter(pc_coords[2], pc_coords[0], color='blue', marker='x') - ax_row[1].imshow(vol[:,vol.shape[1]//2], cmap='gray') + ax_row[1].scatter(ac_coords[2], ac_coords[0], color="red", marker="x") + ax_row[1].scatter(pc_coords[2], pc_coords[0], color="blue", marker="x") + ax_row[1].imshow(vol[:, vol.shape[1] // 2], cmap="gray") # Coronal view - ax_row[2].scatter(ac_coords[1], ac_coords[0], color='red', marker='x') - ax_row[2].scatter(pc_coords[1], pc_coords[0], color='blue', marker='x') - ax_row[2].imshow(vol[:,:,vol.shape[2]//2], cmap='gray') - - -def visualize_coordinate_spaces(orig, upright, standardized, - ac_coords_orig, pc_coords_orig, - ac_coords_3d, pc_coords_3d, - ac_coords_standardized, pc_coords_standardized, - output_dir): + ax_row[2].scatter(ac_coords[1], ac_coords[0], color="red", marker="x") + ax_row[2].scatter(pc_coords[1], pc_coords[0], color="blue", marker="x") + ax_row[2].imshow(vol[:, :, vol.shape[2] // 2], cmap="gray") + + +def visualize_coordinate_spaces( + orig, + upright, + standardized, + ac_coords_orig, + pc_coords_orig, + ac_coords_3d, + pc_coords_3d, + ac_coords_standardized, + pc_coords_standardized, + output_dir, +): """ Visualize the AC and PC coordinates in different coordinate spaces for testing/debugging. - + Args: orig: Original image volume vol: Volume in fsaverage space @@ -58,102 +59,44 @@ def visualize_coordinate_spaces(orig, upright, standardized, # Original space - using plot_standardized_space plot_standardized_space(ax[0], orig.get_fdata(), ac_coords_orig, pc_coords_orig) - ax[0,0].set_title('Orig') + ax[0, 0].set_title("Orig") # Fsaverage space - plot_standardized_space(ax[1], upright, ac_coords_3d, pc_coords_3d) - ax[1,0].set_title('Fsaverage') + plot_standardized_space(ax[1], upright, ac_coords_3d, pc_coords_3d) + ax[1, 0].set_title("Fsaverage") # Standardized space plot_standardized_space(ax[2], standardized, ac_coords_standardized, pc_coords_standardized) - ax[2,0].set_title('Standardized') + ax[2, 0].set_title("Standardized") # Format all subplots for a in ax.flatten(): - a.set_aspect('equal', adjustable='box') - a.axis('off') + a.set_aspect("equal", adjustable="box") + a.axis("off") - plt.savefig(Path(output_dir) / "ac_pc_spaces.png", dpi=300, bbox_inches='tight') + plt.savefig(Path(output_dir) / "ac_pc_spaces.png", dpi=300, bbox_inches="tight") plt.show() plt.close() -# def map_image_to_standard_space(orig: nib.MGHImage, -# ac_coords_3d: np.ndarray, -# pc_coords_3d: np.ndarray, -# orig_fsaverage_vox2vox: np.ndarray, -# output_dir: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: -# """Maps an input image to standard space using AC-PC alignment. - -# This function performs the following transformations: -# 1. 
Maps the image from original space to fsaverage space -# 2. Applies nodding correction -# 3. Translates AC point to center - -# Args: -# orig: Original input image as MGHImage -# ac_coords_3d: Anterior commissure coordinates in 3D -# pc_coords_3d: Posterior commissure coordinates in 3D -# orig_fsaverage_vox2vox: Transform matrix from original to fsaverage space -# output_dir: Directory to save intermediate visualization volumes - -# Returns: -# Tuple containing: -# - orig_to_standardized_vox2vox: Final transformation matrix -# - ac_coords_standardized: AC coordinates in standardized space -# - pc_coords_standardized: PC coordinates in standardized space -# - ac_coords_orig: Original AC coordinates -# - pc_coords_orig: Original PC coordinates -# """ -# # ... existing code ... - -# # Generate intermediate volumes for visualization -# vol = apply_transform_and_map_volume( -# orig.get_fdata(), -# orig_fsaverage_vox2vox, -# orig.affine, -# orig.header, -# Path(output_dir) / "inv_fsaverage_orig_vox2vox.mgz" -# ) - -# vol3 = apply_transform_and_map_volume( -# vol2, -# ac_to_center_translation, -# orig.affine, -# orig.header, -# Path(output_dir) / "translation.mgz" -# ) - -# # Visualize coordinate spaces -# visualize_coordinate_spaces( -# orig, vol, vol2, vol3, -# ac_coords_orig, pc_coords_orig, -# ac_coords_3d, pc_coords_3d, -# ac_coords_standardized, pc_coords_standardized, -# output_dir -# ) - -# return orig_to_standardized_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig - - - - -def plot_contours(transformed: np.ndarray, - split_contours: List[np.ndarray], - split_contours_hofer_frahm: List[np.ndarray], - midline_equidistant: np.ndarray, - levelpaths: List[np.ndarray], - output_path: str, - ac_coords: np.ndarray, - pc_coords: np.ndarray, - vox_size: float, - title: str = None) -> None: +def plot_contours( + transformed: np.ndarray, + split_contours: list[np.ndarray], + split_contours_hofer_frahm: list[np.ndarray], + midline_equidistant: np.ndarray, + levelpaths: list[np.ndarray], + output_path: str, + ac_coords: np.ndarray, + pc_coords: np.ndarray, + vox_size: float, + title: str = None, +) -> None: """Plots corpus callosum contours and segmentations. - + Creates a figure with three subplots showing: 1. Midline-based subsegmentation 2. Hofer-Frahm segmentation scheme 3. 
Midline and levelpaths visualization - + Args: transformed: The transformed brain image array split_contours: List of contour arrays for midline-based segmentation @@ -166,111 +109,118 @@ def plot_contours(transformed: np.ndarray, """ # scale contour data by vox_size - split_contours = [split_contour * vox_size for split_contour in split_contours] if split_contours is not None else None - split_contours_hofer_frahm = [split_contour * vox_size for split_contour in split_contours_hofer_frahm] if split_contours_hofer_frahm is not None else None + split_contours = ( + [split_contour * vox_size for split_contour in split_contours] if split_contours is not None else None + ) + split_contours_hofer_frahm = ( + [split_contour * vox_size for split_contour in split_contours_hofer_frahm] + if split_contours_hofer_frahm is not None + else None + ) midline_equidistant = midline_equidistant * vox_size levelpaths = [levelpath * vox_size for levelpath in levelpaths] - - NO_PLOTS = 1 if split_contours is not None: NO_PLOTS += 1 if split_contours_hofer_frahm is not None: NO_PLOTS += 1 - fig, ax = plt.subplots(1,NO_PLOTS, sharex=True, sharey=True, figsize=(15, 10)) + fig, ax = plt.subplots(1, NO_PLOTS, sharex=True, sharey=True, figsize=(15, 10)) PLT_NUM = 0 if split_contours is not None: - ax[PLT_NUM].imshow(transformed[transformed.shape[0]//2], cmap='gray') - #ax[0].imshow(cc_mask, cmap='autumn') + ax[PLT_NUM].imshow(transformed[transformed.shape[0] // 2], cmap="gray") + # ax[0].imshow(cc_mask, cmap='autumn') ax[PLT_NUM].set_title(title) for i in range(len(split_contours)): - ax[PLT_NUM].fill(split_contours[i][0,:], -split_contours[i][1,:], color='steelblue', alpha=0.25) - ax[PLT_NUM].plot(split_contours[i][0,:], -split_contours[i][1,:], color='mediumblue', linestyle='dotted', linewidth=0.7) - ax[PLT_NUM].plot(split_contours[0][0,:], -split_contours[0][1,:], color='mediumblue', linewidth=0.7) - ax[PLT_NUM].scatter(ac_coords[1], ac_coords[0], color='red', marker='x') - ax[PLT_NUM].scatter(pc_coords[1], pc_coords[0], color='blue', marker='x') + ax[PLT_NUM].fill(split_contours[i][0, :], -split_contours[i][1, :], color="steelblue", alpha=0.25) + ax[PLT_NUM].plot( + split_contours[i][0, :], -split_contours[i][1, :], color="mediumblue", linestyle="dotted", linewidth=0.7 + ) + ax[PLT_NUM].plot(split_contours[0][0, :], -split_contours[0][1, :], color="mediumblue", linewidth=0.7) + ax[PLT_NUM].scatter(ac_coords[1], ac_coords[0], color="red", marker="x") + ax[PLT_NUM].scatter(pc_coords[1], pc_coords[0], color="blue", marker="x") PLT_NUM += 1 if split_contours_hofer_frahm is not None: - - ax[PLT_NUM].imshow(transformed[transformed.shape[0]//2], cmap='gray') - #ax[1].imshow(cc_mask, cmap='autumn') - ax[PLT_NUM].set_title('Hofer-Frahm Jaenecke') + ax[PLT_NUM].imshow(transformed[transformed.shape[0] // 2], cmap="gray") + # ax[1].imshow(cc_mask, cmap='autumn') + ax[PLT_NUM].set_title("Hofer-Frahm Jaenecke") for i in range(len(split_contours_hofer_frahm)): - ax[PLT_NUM].fill(split_contours_hofer_frahm[i][0,:], -split_contours_hofer_frahm[i][1,:], color='steelblue', alpha=0.25) - ax[PLT_NUM].plot([split_contours_hofer_frahm[i][0,0], split_contours_hofer_frahm[i][0,-1]], [-split_contours_hofer_frahm[i][1,0], -split_contours_hofer_frahm[i][1,-1]], color='mediumblue', linestyle='dotted', linewidth=0.7) - ax[PLT_NUM].plot(split_contours_hofer_frahm[0][0,:], -split_contours_hofer_frahm[0][1,:], color='mediumblue', linewidth=0.7) - ax[PLT_NUM].scatter(ac_coords[1], ac_coords[0], color='red', marker='x') - 
ax[PLT_NUM].scatter(pc_coords[1], pc_coords[0], color='blue', marker='x') + ax[PLT_NUM].fill( + split_contours_hofer_frahm[i][0, :], -split_contours_hofer_frahm[i][1, :], color="steelblue", alpha=0.25 + ) + ax[PLT_NUM].plot( + [split_contours_hofer_frahm[i][0, 0], split_contours_hofer_frahm[i][0, -1]], + [-split_contours_hofer_frahm[i][1, 0], -split_contours_hofer_frahm[i][1, -1]], + color="mediumblue", + linestyle="dotted", + linewidth=0.7, + ) + ax[PLT_NUM].plot( + split_contours_hofer_frahm[0][0, :], -split_contours_hofer_frahm[0][1, :], color="mediumblue", linewidth=0.7 + ) + ax[PLT_NUM].scatter(ac_coords[1], ac_coords[0], color="red", marker="x") + ax[PLT_NUM].scatter(pc_coords[1], pc_coords[0], color="blue", marker="x") PLT_NUM += 1 reference_contour = split_contours[0] if split_contours is not None else split_contours_hofer_frahm[0] - ax[PLT_NUM].imshow(transformed[transformed.shape[0]//2], cmap='gray') - #ax[2].imshow(cc_mask, cmap='autumn') + ax[PLT_NUM].imshow(transformed[transformed.shape[0] // 2], cmap="gray") + # ax[2].imshow(cc_mask, cmap='autumn') for i in range(len(levelpaths)): - ax[PLT_NUM].plot(levelpaths[i][:,0], -levelpaths[i][:,1], color='brown', linewidth=0.8) - ax[PLT_NUM].set_title('Midline & Levelpaths') - ax[PLT_NUM].plot(midline_equidistant[:,0], -midline_equidistant[:,1], color='red') - ax[PLT_NUM].plot(reference_contour[0,:], -reference_contour[1,:], color='red', linewidth=0.5) + ax[PLT_NUM].plot(levelpaths[i][:, 0], -levelpaths[i][:, 1], color="brown", linewidth=0.8) + ax[PLT_NUM].set_title("Midline & Levelpaths") + ax[PLT_NUM].plot(midline_equidistant[:, 0], -midline_equidistant[:, 1], color="red") + ax[PLT_NUM].plot(reference_contour[0, :], -reference_contour[1, :], color="red", linewidth=0.5) for a in ax.flatten(): - a.set_aspect('equal', adjustable='box') - a.axis('off') + a.set_aspect("equal", adjustable="box") + a.axis("off") # get bounding box of countours padding = 30 - ax[0].set_xlim(reference_contour[0,:].min()-padding, reference_contour[0,:].max()+padding) - ax[0].set_ylim((-reference_contour[1,:]).max()+padding, (-reference_contour[1,:]).min()-padding) - - - plt.savefig(output_path, dpi=300, bbox_inches='tight') - #plt.show() + ax[0].set_xlim(reference_contour[0, :].min() - padding, reference_contour[0, :].max() + padding) + ax[0].set_ylim((-reference_contour[1, :]).max() + padding, (-reference_contour[1, :]).min() - padding) + plt.savefig(output_path, dpi=300, bbox_inches="tight") + # plt.show() def plot_midplane(grid_orig, orig): """ Creates a 3D visualization of grid points in original image space. - + Args: grid_orig: Grid points in original space orig: Original image for dimension reference """ # Create a figure showing grid points in original space - + # Create 3D plot fig = plt.figure(figsize=(10, 10)) - ax = fig.add_subplot(111, projection='3d') - + ax = fig.add_subplot(111, projection="3d") + # Plot every 10th point to avoid overcrowding sample_idx = np.arange(0, grid_orig.shape[1], 40) ax.scatter( - grid_orig[0,sample_idx], - grid_orig[1,sample_idx], - grid_orig[2,sample_idx], - c='r', - alpha=0.1, - marker='.' + grid_orig[0, sample_idx], grid_orig[1, sample_idx], grid_orig[2, sample_idx], c="r", alpha=0.1, marker="." 
) - + # Set labels - ax.set_xlabel('X') - ax.set_ylabel('Y') - ax.set_zlabel('Z') - ax.set_title('Grid Points in Original Image Space') - + ax.set_xlabel("X") + ax.set_ylabel("Y") + ax.set_zlabel("Z") + ax.set_title("Grid Points in Original Image Space") + # Set axis limits to image dimensions ax.set_xlim(0, orig.shape[0]) ax.set_ylim(0, orig.shape[1]) ax.set_zlim(0, orig.shape[2]) - + # Save plot plt.show() # plt.savefig('grid_points.png') # plt.close() - From d5ab726318b5a46a0f4727dc983699da279b6ca4 Mon Sep 17 00:00:00 2001 From: ClePol Date: Fri, 19 Sep 2025 15:53:57 +0200 Subject: [PATCH 05/68] formatting and requirements fixes --- CorpusCallosum/cc_visualization.py | 1 + .../data/generate_fsaverage_centroids.py | 7 ++++--- CorpusCallosum/data/read_write.py | 3 ++- CorpusCallosum/fastsurfer_cc.py | 15 ++++++------- .../localization/localization_inference.py | 3 +-- .../registration/mapping_helpers.py | 2 +- .../segmentation/segmentation_inference.py | 7 +++---- .../segmentation_postprocessing.py | 1 - CorpusCallosum/shape/cc_endpoint_heuristic.py | 4 ++-- CorpusCallosum/shape/cc_mesh.py | 11 +++++----- CorpusCallosum/shape/cc_postprocessing.py | 21 +++++++++---------- CorpusCallosum/shape/cc_thickness.py | 8 +++---- .../transforms/localization_transforms.py | 3 ++- .../transforms/segmentation_transforms.py | 2 +- CorpusCallosum/visualization/visualization.py | 3 ++- 15 files changed, 46 insertions(+), 45 deletions(-) diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index def10b62..65abf837 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -1,5 +1,6 @@ import argparse from pathlib import Path + import numpy as np from CorpusCallosum.data.fsaverage_cc_template import load_fsaverage_cc_template diff --git a/CorpusCallosum/data/generate_fsaverage_centroids.py b/CorpusCallosum/data/generate_fsaverage_centroids.py index f97a1777..40a97095 100644 --- a/CorpusCallosum/data/generate_fsaverage_centroids.py +++ b/CorpusCallosum/data/generate_fsaverage_centroids.py @@ -7,12 +7,13 @@ Run this script once to generate the centroids file. 
""" -import os import json +import os from pathlib import Path -import numpy as np + import nibabel as nib -from read_write import get_centroids_from_nib, convert_numpy_to_json_serializable +import numpy as np +from read_write import convert_numpy_to_json_serializable, get_centroids_from_nib def main(): diff --git a/CorpusCallosum/data/read_write.py b/CorpusCallosum/data/read_write.py index 16947de7..5a07f605 100644 --- a/CorpusCallosum/data/read_write.py +++ b/CorpusCallosum/data/read_write.py @@ -1,6 +1,7 @@ import multiprocessing -import numpy as np + import nibabel as nib +import numpy as np def run_in_background(function, debug=False, *args, **kwargs): diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index 14171727..e5064924 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -1,11 +1,8 @@ import argparse import json -import warnings - -warnings.filterwarnings("ignore", message="TypedStorage is deprecated") - from pathlib import Path +# import warnings warnings.filterwarnings("ignore", message="TypedStorage is deprecated") import nibabel as nib import numpy as np import torch @@ -16,12 +13,12 @@ from FastSurferCNN.data_loader.conform import is_conform from CorpusCallosum.data.constants import ( + CC_LABEL, FSAVERAGE_CENTROIDS_PATH, FSAVERAGE_DATA_PATH, - WEIGHTS_PATH, - STANDARD_OUTPUT_PATHS, FSAVERAGE_MIDDLE, - CC_LABEL, + STANDARD_OUTPUT_PATHS, + WEIGHTS_PATH, ) from CorpusCallosum.data.read_write import ( convert_numpy_to_json_serializable, @@ -31,6 +28,8 @@ run_in_background, save_nifti_background, ) +from CorpusCallosum.localization import localization_inference +from CorpusCallosum.registration import find_rigid, lta from CorpusCallosum.registration.mapping_helpers import ( apply_transform_and_map_volume, apply_transform_to_pt, @@ -38,7 +37,9 @@ interpolate_midplane, map_softlabels_to_orig, ) +from CorpusCallosum.segmentation import segmentation_inference, segmentation_postprocessing from CorpusCallosum.shape.cc_postprocessing import create_visualization, process_slices +from FastSurferCNN.data_loader.conform import is_conform def options_parse() -> argparse.Namespace: diff --git a/CorpusCallosum/localization/localization_inference.py b/CorpusCallosum/localization/localization_inference.py index 9055c6d4..a60c28c2 100644 --- a/CorpusCallosum/localization/localization_inference.py +++ b/CorpusCallosum/localization/localization_inference.py @@ -1,8 +1,7 @@ -import torch import numpy as np +import torch from monai import transforms from monai.networks.nets import DenseNet as DenseNet_monai - from transforms.localization_transforms import CropAroundACPCFixedSize diff --git a/CorpusCallosum/registration/mapping_helpers.py b/CorpusCallosum/registration/mapping_helpers.py index 80312a4d..3e4a0d1d 100644 --- a/CorpusCallosum/registration/mapping_helpers.py +++ b/CorpusCallosum/registration/mapping_helpers.py @@ -1,5 +1,5 @@ -import numpy as np import nibabel as nib +import numpy as np from scipy.ndimage import affine_transform diff --git a/CorpusCallosum/segmentation/segmentation_inference.py b/CorpusCallosum/segmentation/segmentation_inference.py index 22113569..154bd9d6 100644 --- a/CorpusCallosum/segmentation/segmentation_inference.py +++ b/CorpusCallosum/segmentation/segmentation_inference.py @@ -1,11 +1,10 @@ -import torch -import numpy as np import nibabel as nib - +import numpy as np +import torch from monai import transforms +from transforms.segmentation_transforms import CropAroundACPC from 
FastSurferCNN.models.networks import FastSurferVINN -from transforms.segmentation_transforms import CropAroundACPC def load_model(checkpoint_path, device=None): diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py index 5ffcfde5..dfba0e0e 100644 --- a/CorpusCallosum/segmentation/segmentation_postprocessing.py +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -5,7 +5,6 @@ from CorpusCallosum.data.constants import CC_LABEL, FORNIX_LABEL - def get_cc_volume(desired_width_mm: int, cc_mask: np.ndarray, voxel_size: tuple[float, float, float]) -> float: """Calculate the volume of the corpus callosum in cubic millimeters. diff --git a/CorpusCallosum/shape/cc_endpoint_heuristic.py b/CorpusCallosum/shape/cc_endpoint_heuristic.py index 2186e581..aafd0af6 100644 --- a/CorpusCallosum/shape/cc_endpoint_heuristic.py +++ b/CorpusCallosum/shape/cc_endpoint_heuristic.py @@ -1,7 +1,7 @@ +import lapy import numpy as np -import skimage.measure import scipy.ndimage -import lapy +import skimage.measure def get_endpoints(cc_mask, AC_2d, PC_2d, resolution, return_coordinates=True, contour_smoothing=1.0): diff --git a/CorpusCallosum/shape/cc_mesh.py b/CorpusCallosum/shape/cc_mesh.py index 1c5e9a26..72399622 100644 --- a/CorpusCallosum/shape/cc_mesh.py +++ b/CorpusCallosum/shape/cc_mesh.py @@ -1,17 +1,16 @@ import tempfile -import numpy as np +import lapy import matplotlib import matplotlib.pyplot as plt -import plotly.graph_objects as go import nibabel as nib -import lapy +import numpy as np +import plotly.graph_objects as go import pyrr import scipy.interpolate from scipy.ndimage import gaussian_filter1d - +from shape.cc_thickness import HiddenPrints, make_mesh_from_contour from whippersnappy.core import snap1 -from shape.cc_thickness import make_mesh_from_contour, HiddenPrints class CC_Mesh(lapy.TriaMesh): @@ -388,9 +387,9 @@ def plot_mesh( fig.write_html(output_path) # Save as interactive HTML else: # For non-interactive display, save to a temporary HTML and open in browser + import os import tempfile import webbrowser - import os temp_path = os.path.join(tempfile.gettempdir(), "cc_mesh_plot.html") fig.write_html(temp_path) diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/cc_postprocessing.py index 7063b01b..3d4bcc31 100644 --- a/CorpusCallosum/shape/cc_postprocessing.py +++ b/CorpusCallosum/shape/cc_postprocessing.py @@ -1,22 +1,21 @@ from pathlib import Path import numpy as np - -from shape.cc_thickness import convert_to_ras, cc_thickness +from shape.cc_endpoint_heuristic import get_endpoints +from shape.cc_mesh import CC_Mesh +from shape.cc_metrics import calculate_cc_index from shape.cc_subsegment_contour import ( + get_primary_eigenvector, + hampel_subdivide_contour, subdivide_contour, - transform_to_acpc_standard, subsegment_midline_orthogonal, - hampel_subdivide_contour, + transform_to_acpc_standard, ) -from shape.cc_endpoint_heuristic import get_endpoints -from shape.cc_metrics import calculate_cc_index -from shape.cc_subsegment_contour import get_primary_eigenvector -from shape.cc_mesh import CC_Mesh -from CorpusCallosum.visualization.visualization import plot_contours -from CorpusCallosum.data.read_write import run_in_background -from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE +from shape.cc_thickness import cc_thickness, convert_to_ras +from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE +from CorpusCallosum.data.read_write import run_in_background +from 
CorpusCallosum.visualization.visualization import plot_contours # assert LIA orientation LIA_ORIENTATION = np.zeros((3,3)) diff --git a/CorpusCallosum/shape/cc_thickness.py b/CorpusCallosum/shape/cc_thickness.py index c64f6228..3d5be36d 100644 --- a/CorpusCallosum/shape/cc_thickness.py +++ b/CorpusCallosum/shape/cc_thickness.py @@ -1,11 +1,11 @@ -import sys import os +import sys -import numpy as np -from lapy import TriaMesh, Solver -from lapy.diffgeo import compute_rotated_f import meshpy.triangle as triangle +import numpy as np import scipy.interpolate +from lapy import Solver, TriaMesh +from lapy.diffgeo import compute_rotated_f class HiddenPrints: diff --git a/CorpusCallosum/transforms/localization_transforms.py b/CorpusCallosum/transforms/localization_transforms.py index a1bf6134..215e4817 100644 --- a/CorpusCallosum/transforms/localization_transforms.py +++ b/CorpusCallosum/transforms/localization_transforms.py @@ -1,5 +1,6 @@ -from monai.transforms import RandomizableTransform, MapTransform import numpy as np +from monai.transforms import MapTransform, RandomizableTransform + class CropAroundACPCFixedSize(RandomizableTransform, MapTransform): """ diff --git a/CorpusCallosum/transforms/segmentation_transforms.py b/CorpusCallosum/transforms/segmentation_transforms.py index 1158d488..ad7d7484 100644 --- a/CorpusCallosum/transforms/segmentation_transforms.py +++ b/CorpusCallosum/transforms/segmentation_transforms.py @@ -1,5 +1,5 @@ -from monai.transforms import RandomizableTransform, MapTransform import numpy as np +from monai.transforms import MapTransform, RandomizableTransform class CropAroundACPC(RandomizableTransform, MapTransform): diff --git a/CorpusCallosum/visualization/visualization.py b/CorpusCallosum/visualization/visualization.py index dcfab1da..beeafb51 100644 --- a/CorpusCallosum/visualization/visualization.py +++ b/CorpusCallosum/visualization/visualization.py @@ -1,6 +1,7 @@ from pathlib import Path -import numpy as np + import matplotlib.pyplot as plt +import numpy as np def plot_standardized_space(ax_row, vol, ac_coords, pc_coords): From 936b9f4aee08476e200fa9148d949e5663847b71 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Fri, 19 Sep 2025 12:23:38 +0200 Subject: [PATCH 06/68] fix typo in comment --- CorpusCallosum/visualization/visualization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CorpusCallosum/visualization/visualization.py b/CorpusCallosum/visualization/visualization.py index beeafb51..1207c548 100644 --- a/CorpusCallosum/visualization/visualization.py +++ b/CorpusCallosum/visualization/visualization.py @@ -181,7 +181,7 @@ def plot_contours( a.set_aspect("equal", adjustable="box") a.axis("off") - # get bounding box of countours + # get bounding box of contours padding = 30 ax[0].set_xlim(reference_contour[0, :].min() - padding, reference_contour[0, :].max() + padding) ax[0].set_ylim((-reference_contour[1, :]).max() + padding, (-reference_contour[1, :]).min() - padding) From 6de7538d954661c301ba7521e0acaeb964739575 Mon Sep 17 00:00:00 2001 From: ClePol Date: Fri, 19 Sep 2025 15:56:14 +0200 Subject: [PATCH 07/68] fixed spelling in comments --- CorpusCallosum/shape/cc_subsegment_contour.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/CorpusCallosum/shape/cc_subsegment_contour.py b/CorpusCallosum/shape/cc_subsegment_contour.py index 977fa353..0ba34e33 100644 --- a/CorpusCallosum/shape/cc_subsegment_contour.py +++ b/CorpusCallosum/shape/cc_subsegment_contour.py @@ -198,7 +198,7 @@ def 
subsegment_midline_orthogonal(midline, area_weights, contour, plot=True, ax= # plot contour ax.plot(contour[0], contour[1], "-", linewidth=2, color="grey") # put text between split points - # add enpoints to split_points + # add endpoints to split_points split_points = split_points.tolist() split_points.insert(0, extremes[0]) split_points.append(extremes[1]) @@ -385,7 +385,7 @@ def hampel_subdivide_contour(contour, num_rays, plot=False, ax=None): [midpoint_lower_edge[1], midpoint_lower_edge[1] + ray_vector[1]], "k--", ) - # pretty plot with areas filles in the polygon and overall area annotated + # pretty plot with areas filled in the polygon and overall area annotated colors = plt.cm.Spectral(np.linspace(0.2, 0.8, len(split_contours))) for color, split_contour in zip(colors, split_contours, strict=True): ax.fill(split_contour[0], split_contour[1], alpha=0.5, color=color) @@ -425,7 +425,7 @@ def subdivide_contour( # extremes = (np.array([extremes[0][0], most_inferior_point - 5]), # np.array([extremes[1][0], most_inferior_point - 5])) # else: - # # get y diffrence between extremes and hline_anchor + # # get y difference between extremes and hline_anchor # y_diff = extremes[1][1] - hline_anchor[1] # extremes = (np.array([extremes[0][0], extremes[0][1] - y_diff]), # np.array([extremes[1][0], extremes[1][1] - y_diff])) @@ -618,7 +618,7 @@ def subdivide_contour( ax.axis("equal") else: SHOW = False - # pretty plot with areas filles in the polygon and overall area annotated + # pretty plot with areas filled in the polygon and overall area annotated colors = plt.cm.Spectral(np.linspace(0.2, 0.8, len(split_contours))) for color, split_contour in zip(colors, split_contours, strict=True): ax.fill(split_contour[0], split_contour[1], alpha=0.5, color=color) @@ -646,7 +646,7 @@ def subdivide_contour( ax.plot(start_point_vline[:, 0], start_point_vline[:, 1], "--", linewidth=2, color="grey") ax.plot(end_point_vline[:, 0], end_point_vline[:, 1], "--", linewidth=2, color="grey") # put text between split points - # add enpoints to split_points + # add endpoints to split_points split_points.insert(0, extremes[0]) split_points.append(extremes[1]) # convert area_weights into fraction of total line length @@ -769,7 +769,7 @@ def get_primary_eigenvector(contour_ras): eigenvalues = eigenvalues[idx] eigenvectors = eigenvectors[:, idx] - # make first eigentor unit length + # make first eigenvector unit length primary_eigenvector = eigenvectors[:, 0] / np.linalg.norm(eigenvectors[:, 0]) pt0 = np.mean(contour_ras, axis=1) pt0 -= np.array([0, 5]) From 8bb88339fc6253b8abca72f77bf99ec80eded31e Mon Sep 17 00:00:00 2001 From: ClePol Date: Fri, 19 Sep 2025 18:09:58 +0200 Subject: [PATCH 08/68] added checkpoint downloading --- CorpusCallosum/config/checkpoint_paths.yaml | 7 ++++ CorpusCallosum/data/constants.py | 5 +-- CorpusCallosum/fastsurfer_cc.py | 24 ++++++----- .../localization/localization_inference.py | 22 ++++++++-- .../segmentation/segmentation_inference.py | 12 +++++- CorpusCallosum/shape/cc_endpoint_heuristic.py | 9 +++- CorpusCallosum/shape/cc_mesh.py | 33 +++++++++------ CorpusCallosum/shape/cc_postprocessing.py | 8 ++-- CorpusCallosum/utils/checkpoint.py | 17 ++++++++ CorpusCallosum/visualization/visualization.py | 2 + FastSurferCNN/download_checkpoints.py | 42 ++++++++++++++----- doc/api/recon_surf.rst | 1 - 12 files changed, 137 insertions(+), 45 deletions(-) create mode 100644 CorpusCallosum/config/checkpoint_paths.yaml create mode 100644 CorpusCallosum/utils/checkpoint.py diff --git
a/CorpusCallosum/config/checkpoint_paths.yaml b/CorpusCallosum/config/checkpoint_paths.yaml new file mode 100644 index 00000000..ca78b7da --- /dev/null +++ b/CorpusCallosum/config/checkpoint_paths.yaml @@ -0,0 +1,7 @@ +url: +- "https://zenodo.org/records/17141933/files" +- "https://b2share.fz-juelich.de/api/files/e4eb699c-ba68-4470-9f3d-89ceeee1a334" + +checkpoint: + segmentation: "checkpoints/FastSurferCC_segmentation_v1.0.0.pkl" + localization: "checkpoints/FastSurferCC_localization_v1.0.0.pkl" diff --git a/CorpusCallosum/data/constants.py b/CorpusCallosum/data/constants.py index 60125bd2..e3767349 100644 --- a/CorpusCallosum/data/constants.py +++ b/CorpusCallosum/data/constants.py @@ -17,7 +17,6 @@ "upright_lta": "transforms/upright.lta", "orient_volume_lta": "transforms/orient_volume.lta", "orig_space_segmentation": "mri/segmentation_orig_space.mgz", - "debug_image": "stats/cc_postprocessing.png", - "qc_view": "qc-snapshots/corpus_callosum.png", - "qc_view3d": "qc-snapshots/corpus_callosum_thickness.png" + "debug_image": "qc_snapshots/corpus_callosum.png", + "thickness_image": "qc_snapshots/corpus_callosum_thickness.png" } \ No newline at end of file diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index e5064924..342c795b 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -6,11 +6,6 @@ import nibabel as nib import numpy as np import torch -from CorpusCallosum.localization import localization_inference -from recon_surf import lta -from recon_surf.align_points import find_rigid -from CorpusCallosum.segmentation import segmentation_inference, segmentation_postprocessing -from FastSurferCNN.data_loader.conform import is_conform from CorpusCallosum.data.constants import ( CC_LABEL, @@ -29,7 +24,6 @@ save_nifti_background, ) from CorpusCallosum.localization import localization_inference -from CorpusCallosum.registration import find_rigid, lta from CorpusCallosum.registration.mapping_helpers import ( apply_transform_and_map_volume, apply_transform_to_pt, @@ -40,6 +34,8 @@ from CorpusCallosum.segmentation import segmentation_inference, segmentation_postprocessing from CorpusCallosum.shape.cc_postprocessing import create_visualization, process_slices from FastSurferCNN.data_loader.conform import is_conform +from recon_surf import lta +from recon_surf.align_points import find_rigid def options_parse() -> argparse.Namespace: @@ -155,17 +151,22 @@ def options_parse() -> argparse.Namespace: parser.add_argument( "--debug_image_path", type=str, - help="Path for debug visualization image (default: subject_dir/stats/cc_postprocessing.png)", + help="Path for debug visualization image (default: subject_dir/qc_snapshots/cc_postprocessing.png)", default=None, ) - - # Template saving argument parser.add_argument( "--save_template", type=str, help="Directory path where to save contours.txt and thickness_values.txt files", default=None, ) + parser.add_argument( + "--thickness_image_path", + type=str, + help="Path for thickness image (default: subject_dir/qc_snapshots/corpus_callosum_thickness_3d.png)", + default=None, + ) + args = parser.parse_args() @@ -375,6 +376,7 @@ def main( orig_space_segmentation_path: str | Path = None, debug_image_path: str | Path = None, save_template: str | Path | None = None, + thickness_image_path: str | Path = None, cpu: bool = False, ) -> None: """Main pipeline function for corpus callosum analysis. 
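A note on the checkpoint configuration introduced in this patch: checkpoint_paths.yaml maps each model name to a checkpoint file and lists mirror URLs that are tried in order. The pipeline resolves it via load_checkpoint_config_defaults and get_checkpoints from FastSurferCNN.utils.checkpoint; the standalone sketch below only assumes PyYAML and the YAML layout shown above, and the URL-joining scheme (mirror plus bare filename) is an assumption, not the library's confirmed behavior.

    from pathlib import Path

    import yaml

    def resolve_checkpoints(yaml_path):
        """Map each model name to its local path and candidate download URLs."""
        cfg = yaml.safe_load(Path(yaml_path).read_text())
        mirrors = cfg["url"]  # list of mirrors, tried in order
        resolved = {}
        for model, rel_path in cfg["checkpoint"].items():
            # assumption: every mirror serves the checkpoint under its bare filename
            filename = Path(rel_path).name
            resolved[model] = (Path(rel_path), [f"{m}/{filename}" for m in mirrors])
        return resolved

For the config above, this would pair the segmentation and localization checkpoints with their Zenodo and B2SHARE candidate URLs.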
@@ -408,6 +410,8 @@ def main( (default: output_dir/mri/segmentation_orig_space.mgz) debug_image_path: Path for debug visualization image (default: output_dir/stats/cc_postprocessing.png) save_template: Directory path where to save contours.txt and thickness_values.txt files + thickness_image_path: Path for thickness image + (default: output_dir/qc_snapshots/corpus_callosum_thickness_3d.png) cpu: Force CPU usage even when CUDA is available The function saves multiple outputs to specified paths or default locations in output_dir: @@ -451,7 +455,7 @@ def main( "center around the mid-sagittal plane)" ) - if not is_conform(orig, conform_vox_size=orig.header.get_zooms()[0]): + if not is_conform(orig): print("Error: MRI is not conformed, please run conform.py or mri_convert to conform the image.") exit(1) diff --git a/CorpusCallosum/localization/localization_inference.py b/CorpusCallosum/localization/localization_inference.py index a60c28c2..db7cc84e 100644 --- a/CorpusCallosum/localization/localization_inference.py +++ b/CorpusCallosum/localization/localization_inference.py @@ -1,8 +1,14 @@ +from pathlib import Path + import numpy as np import torch from monai import transforms -from monai.networks.nets import DenseNet as DenseNet_monai -from transforms.localization_transforms import CropAroundACPCFixedSize +from monai.networks.nets import DenseNet + +from CorpusCallosum.transforms.localization_transforms import CropAroundACPCFixedSize +from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML +from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults +from FastSurferCNN.download_checkpoints import main as download_checkpoints def load_model(checkpoint_path, device=None): @@ -20,7 +26,7 @@ def load_model(checkpoint_path, device=None): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Initialize model architecture (must match training) - model = DenseNet_monai( # densenet201 + model = DenseNet( # densenet201 spatial_dims=2, in_channels=3, out_channels=4, @@ -32,9 +38,17 @@ def load_model(checkpoint_path, device=None): norm=("batch", {"affine": True}), dropout_prob=0.2 ) + + download_checkpoints(cc=True) + cc_config = load_checkpoint_config_defaults( + "checkpoint", + filename=CC_YAML, + ) + checkpoint_path = cc_config['localization'] + # Load state dict - if isinstance(checkpoint_path, str): + if isinstance(checkpoint_path, str) or isinstance(checkpoint_path, Path): state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True) if isinstance(state_dict, dict) and 'model_state_dict' in state_dict: state_dict = state_dict['model_state_dict'] diff --git a/CorpusCallosum/segmentation/segmentation_inference.py b/CorpusCallosum/segmentation/segmentation_inference.py index 154bd9d6..3334b49d 100644 --- a/CorpusCallosum/segmentation/segmentation_inference.py +++ b/CorpusCallosum/segmentation/segmentation_inference.py @@ -2,8 +2,11 @@ import numpy as np import torch from monai import transforms -from transforms.segmentation_transforms import CropAroundACPC +from CorpusCallosum.transforms.segmentation_transforms import CropAroundACPC +from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML +from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults +from FastSurferCNN.download_checkpoints import main as download_checkpoints from FastSurferCNN.models.networks import FastSurferVINN @@ -43,6 +46,13 @@ def load_model(checkpoint_path, device=None): } model = FastSurferVINN(params) + 
download_checkpoints(cc=True) + cc_config = load_checkpoint_config_defaults( + "checkpoint", + filename=CC_YAML, + ) + checkpoint_path = cc_config['segmentation'] + #model = torch.load(checkpoint_path, map_location=device, weights_only=False) weights = torch.load(checkpoint_path, weights_only=True, map_location=device) model.load_state_dict(weights) diff --git a/CorpusCallosum/shape/cc_endpoint_heuristic.py b/CorpusCallosum/shape/cc_endpoint_heuristic.py index aafd0af6..868d07c2 100644 --- a/CorpusCallosum/shape/cc_endpoint_heuristic.py +++ b/CorpusCallosum/shape/cc_endpoint_heuristic.py @@ -43,7 +43,14 @@ def get_endpoints(cc_mask, AC_2d, PC_2d, resolution, return_coordinates=True, co # gaussian_cc_mask = scipy.ndimage.gaussian_filter(gaussian_cc_mask, sigma=1.0) contour = skimage.measure.find_contours(gaussian_cc_mask, level=0.5)[0].T - contour = lapy.tria_mesh.TriaMesh.iterative_resample_polygon(contour.T, 701).T + + + # Add z=0 coordinate to make 3D, then remove it after resampling + contour_3d = np.vstack([contour, np.zeros(contour.shape[1])]) + contour_3d = lapy.tria_mesh.TriaMesh._TriaMesh__resample_polygon(contour_3d.T, 701).T + contour = contour_3d[:2] + + contour = contour[:, :-1] rotated_AC_2d = np.array(rotated_AC_2d).astype(float) diff --git a/CorpusCallosum/shape/cc_mesh.py b/CorpusCallosum/shape/cc_mesh.py index 72399622..645a1c88 100644 --- a/CorpusCallosum/shape/cc_mesh.py +++ b/CorpusCallosum/shape/cc_mesh.py @@ -1,4 +1,5 @@ import tempfile +from pathlib import Path import lapy import matplotlib @@ -1026,7 +1027,7 @@ def __create_cc_viewmat(): return viewmat - def snap_cc_picture(self, output_path: str): + def snap_cc_picture(self, output_path: str, fssurf_file: str | None = None, overlay_file: str | None = None): """Snap a picture of the corpus callosum mesh. Takes a snapshot of the mesh from a predefined viewpoint, with optional thickness @@ -1034,7 +1035,10 @@ def snap_cc_picture(self, output_path: str): Args: output_path (str): Path where to save the snapshot image. - + fssurf_file (str, optional): Path to a FreeSurfer surface file to use for the snapshot - if not provided, + the mesh is saved to a temporary file. + overlay_file (str, optional): Path to a FreeSurfer overlay file to use for the snapshot - if not provided, + the thickness overlay is written to a temporary file. Note: This method uses a temporary file to store the mesh and overlay data during the snapshot process.
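A note on the contour resampling in the cc_endpoint_heuristic hunk above: lapy keeps its polygon resampler as a double-underscore method of TriaMesh, so callers outside the class must use the name-mangled form _TriaMesh__resample_polygon, and because that resampler works on 3-D points the 2-D contour is padded with z = 0 and the z row is dropped again afterwards. The removed resample_poly.py implemented the same underlying idea: equidistant points obtained by interpolating each coordinate against cumulative arc length. A self-contained sketch of the round-trip, using a stand-in resampler rather than lapy's private one:

    import numpy as np

    def resample_polygon(xy, n_points=100):
        # equidistant resampling along cumulative arc length,
        # mirroring the removed resample_poly.py
        d = np.cumsum(np.r_[0.0, np.sqrt((np.diff(xy, axis=0) ** 2).sum(axis=1))])
        d_sampled = np.linspace(0.0, d[-1], n_points)
        return np.column_stack(
            [np.interp(d_sampled, d, xy[:, k]) for k in range(xy.shape[1])]
        )

    contour = np.random.rand(2, 50)                                # (2, N) contour, as in get_endpoints
    contour_3d = np.vstack([contour, np.zeros(contour.shape[1])])  # pad z=0 -> (3, N)
    contour_3d = resample_polygon(contour_3d.T, 701).T             # resample as (N, 3) points
    contour = contour_3d[:2]                                       # drop z again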
@@ -1045,17 +1049,22 @@ def snap_cc_picture(self, output_path: str): return # create temp file - temp_file = tempfile.NamedTemporaryFile(suffix=".fssurf", delete=True) - self.write_fssurf(temp_file.name) - - # Write thickness values as overlay - if hasattr(self, "mesh_vertex_colors"): - overlay_file = tempfile.NamedTemporaryFile(suffix=".w", delete=True) - # Write thickness values in FreeSurfer .w format - nib.freesurfer.write_morph_data(overlay_file.name, self.mesh_vertex_colors) - overlaypath = overlay_file.name + if fssurf_file is None: + temp_file = tempfile.NamedTemporaryFile(suffix=".fssurf", delete=True) + self.write_fssurf(temp_file.name) + else: + temp_file = Path(fssurf_file) + + if overlay_file is None: + if hasattr(self, "mesh_vertex_colors"): + overlay_file = tempfile.NamedTemporaryFile(suffix=".w", delete=True) + # Write thickness values in FreeSurfer .w format + nib.freesurfer.write_morph_data(overlay_file.name, self.mesh_vertex_colors) + overlaypath = overlay_file.name + else: + overlaypath = None else: - overlaypath = None + overlaypath = Path(overlay_file).name snap1( temp_file.name, diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/cc_postprocessing.py index 3d4bcc31..3017b3dc 100644 --- a/CorpusCallosum/shape/cc_postprocessing.py +++ b/CorpusCallosum/shape/cc_postprocessing.py @@ -202,7 +202,8 @@ def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thi def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac_coords, pc_coords, num_thickness_points, subdivisions, subdivision_method, contour_smoothing, - output_dir, debug_image_path=None, vox_size=None, verbose=False, save_template=None): + output_dir, debug_image_path=None, thickness_image_path=None, vox_size=None, verbose=False, + save_template=None): """Process corpus callosum slices based on selection mode. Handles the processing of either a single middle slice, all slices, or a specific slice, @@ -324,8 +325,9 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac cc_mesh.create_mesh() cc_mesh.smooth_(1) cc_mesh.plot_mesh() - cc_mesh.write_vtk(str(output_dir / 'cc_mesh.vtk')) - cc_mesh.snap_cc_picture(str(output_dir / 'cc_mesh_snap.png')) + #cc_mesh.write_vtk(str(output_dir / 'cc_mesh.vtk')) + cc_mesh.snap_cc_picture(str(output_dir / thickness_image_path)) + if not slice_results: print("Error: No valid slices were found for postprocessing") diff --git a/CorpusCallosum/utils/checkpoint.py b/CorpusCallosum/utils/checkpoint.py new file mode 100644 index 00000000..2fd19f21 --- /dev/null +++ b/CorpusCallosum/utils/checkpoint.py @@ -0,0 +1,17 @@ +# Copyright 2024 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT + +YAML_DEFAULT = FASTSURFER_ROOT / "CorpusCallosum/config/checkpoint_paths.yaml" diff --git a/CorpusCallosum/visualization/visualization.py b/CorpusCallosum/visualization/visualization.py index 1207c548..4ee3fd5e 100644 --- a/CorpusCallosum/visualization/visualization.py +++ b/CorpusCallosum/visualization/visualization.py @@ -186,6 +186,8 @@ def plot_contours( ax[0].set_xlim(reference_contour[0, :].min() - padding, reference_contour[0, :].max() + padding) ax[0].set_ylim((-reference_contour[1, :]).max() + padding, (-reference_contour[1, :]).min() - padding) + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + plt.savefig(output_path, dpi=300, bbox_inches="tight") # plt.show() diff --git a/FastSurferCNN/download_checkpoints.py b/FastSurferCNN/download_checkpoints.py index 23f65feb..7d7a620a 100644 --- a/FastSurferCNN/download_checkpoints.py +++ b/FastSurferCNN/download_checkpoints.py @@ -17,6 +17,7 @@ from CerebNet.utils.checkpoint import ( YAML_DEFAULT as CEREBNET_YAML, ) +from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML from FastSurferCNN.utils import PLANES from FastSurferCNN.utils.checkpoint import ( YAML_DEFAULT as VINN_YAML, @@ -26,9 +27,7 @@ get_checkpoints, load_checkpoint_config_defaults, ) -from HypVINN.utils.checkpoint import ( - YAML_DEFAULT as HYPVINN_YAML, -) +from HypVINN.utils.checkpoint import YAML_DEFAULT as HYPVINN_YAML class ConfigCache: @@ -40,9 +39,12 @@ def cerebnet_url(self): def hypvinn_url(self): return load_checkpoint_config_defaults("url", filename=HYPVINN_YAML) + + def cc_url(self): + return load_checkpoint_config_defaults("url", filename=CC_YAML) def all_urls(self): - return self.vinn_url() + self.cerebnet_url() + self.hypvinn_url() + return self.vinn_url() + self.cerebnet_url() + self.hypvinn_url() + self.cc_url() defaults = ConfigCache() @@ -72,6 +74,12 @@ def make_parser(): action="store_true", help="Check and download CerebNet default checkpoints", ) + parser.add_argument( + "--cc", + default=False, + action="store_true", + help="Check and download Corpus Callosum default checkpoints", + ) parser.add_argument( "--hypvinn", @@ -99,16 +107,20 @@ def make_parser(): def main( - vinn: bool, - cerebnet: bool, - hypvinn: bool, - all: bool, - files: list[str], + vinn: bool = False, + cerebnet: bool = False, + hypvinn: bool = False, + cc: bool = False, + all: bool = False, + files: list[str] = None, url: str | None = None, ) -> int | str: - if not vinn and not files and not cerebnet and not hypvinn and not all: + if not vinn and not files and not cerebnet and not hypvinn and not cc and not all: return ("Specify either files to download or --vinn, --cerebnet, " "--hypvinn, or --all, see help -h.") + + if files is None: + files = [] try: # FastSurferVINN checkpoints @@ -141,6 +153,16 @@ def main( *(hypvinn_config[plane] for plane in PLANES), urls=defaults.hypvinn_url() if url is None else [url], ) + # Corpus Callosum checkpoints + if cc or all: + cc_config = load_checkpoint_config_defaults( + "checkpoint", + filename=CC_YAML, + ) + get_checkpoints( + *(cc_config[model] for model in cc_config.keys()), + urls=defaults.cc_url() if url is None else [url], + ) for fname in files: check_and_download_ckpts( fname, diff --git a/doc/api/recon_surf.rst b/doc/api/recon_surf.rst index 4e19a65b..0387d24e 100644 --- a/doc/api/recon_surf.rst +++ b/doc/api/recon_surf.rst @@ -13,7 +13,6 @@ recon_surf fs_balabels map_surf_label N4_bias_correct - paint_cc_into_pred rewrite_oriented_surface 
rewrite_mc_surface rotate_sphere From f03997c4bbf5707e6d7ceda9f39a1901d493e49e Mon Sep 17 00:00:00 2001 From: ClePol Date: Mon, 22 Sep 2025 17:18:32 +0200 Subject: [PATCH 09/68] added writing soft labels fixed QC image for multiple slices added additional output files (surfaces) adjusted naming added logging and better file naming --- CorpusCallosum/data/constants.py | 32 +++- .../data/generate_fsaverage_centroids.py | 6 + CorpusCallosum/data/read_write.py | 5 + CorpusCallosum/fastsurfer_cc.py | 167 ++++++++++++++---- .../registration/mapping_helpers.py | 6 + CorpusCallosum/shape/cc_mesh.py | 7 + CorpusCallosum/shape/cc_postprocessing.py | 125 +++++++++---- CorpusCallosum/shape/cc_thickness.py | 13 +- CorpusCallosum/visualization/visualization.py | 2 +- recon_surf/align_points.py | 14 +- 10 files changed, 280 insertions(+), 97 deletions(-) diff --git a/CorpusCallosum/data/constants.py b/CorpusCallosum/data/constants.py index e3767349..b41afc7e 100644 --- a/CorpusCallosum/data/constants.py +++ b/CorpusCallosum/data/constants.py @@ -7,16 +7,30 @@ FSAVERAGE_MIDDLE = 128 # Middle slice index in fsaverage space CC_LABEL = 192 # Label value for corpus callosum in segmentation FORNIX_LABEL = 250 # Label value for fornix in segmentation +SUBSEGEMNT_LABELS = [251, 252, 253, 254, 255] # labels for subsegments in segmentation STANDARD_OUTPUT_PATHS = { - "upright_volume": "mri/upright_volume.mgz", - "segmentation": "mri/cc_segmentation.mgz", - "postproc_results": "stats/cc_postproc_results.json", - "cc_markers": "stats/cc_markers.json", - "upright_lta": "transforms/upright.lta", - "orient_volume_lta": "transforms/orient_volume.lta", - "orig_space_segmentation": "mri/segmentation_orig_space.mgz", - "debug_image": "qc_snapshots/corpus_callosum.png", - "thickness_image": "qc_snapshots/corpus_callosum_thickness.png" + ## images + "upright_volume": None, # orig.mgz mapped to upright space + ## segmentations + "segmentation": "mri/callosum_seg_upright.mgz", # corpus callosum segmentation in upright space + "orig_space_segmentation": "mri/callosum_seg_aseg_space.mgz", # cc segmentation in input segmentations space + "softlabels_cc": "mri/callosum_seg_soft.mgz", # cc softlabels in upright space + "softlabels_fn": "mri/fornix_seg_soft.mgz", # fornix softlabels in upright space + "softlabels_background": "mri/background_seg_soft.mgz", # background softlabels in upright space + ## stats + "cc_markers": "stats/callosum.CC.midslice.json", # cc metrics for middle slice + "postproc_results": "stats/callosum.CC.all_slices.json", # cc metrics for all slices + ## transforms + "upright_lta": "mri/transforms/cc_up.lta", # lta transform from orig to upright space + "orient_volume_lta": "mri/transforms/orient_volume.lta", # lta transform from orig to upright+acpc corrected space + ## qc + "debug_image": "qc_snapshots/callosum.png", # debug image of cc contours + "thickness_image": "qc_snapshots/callosum_thickness.png", # whippersnappy 3D image of cc thickness + "cc_html": "qc_snapshots/corpus_callosum.html", # plotly cc visualization + ## surface + "surf_file": "surf/callosum.surf", # cc surface file + "overlay_file": "surf/callosum.thickness.w", # cc surface overlay file + "vtk_file": "qc_snapshots/callosum_mesh.vtk", # vtk file of cc mesh } \ No newline at end of file diff --git a/CorpusCallosum/data/generate_fsaverage_centroids.py b/CorpusCallosum/data/generate_fsaverage_centroids.py index 40a97095..2cc827d8 100644 --- a/CorpusCallosum/data/generate_fsaverage_centroids.py +++ 
b/CorpusCallosum/data/generate_fsaverage_centroids.py
@@ -15,6 +15,10 @@
 import numpy as np
 from read_write import convert_numpy_to_json_serializable, get_centroids_from_nib
 
+import FastSurferCNN.utils.logging as logging
+
+logger = logging.get_logger(__name__)
+
 
 def main():
     """Generate and save fsaverage centroids to a static file."""
@@ -52,6 +56,7 @@ def main():
 
     # Save centroids to JSON file
     centroids_output_path = Path(__file__).parent / "fsaverage_centroids.json"
+    logger.info(f"Saving fsaverage centroids to {centroids_output_path}")
     with open(centroids_output_path, 'w') as f:
         json.dump(centroids_serializable, f, indent=2)
 
@@ -92,6 +97,7 @@ def main():
 
     # Save combined data to JSON file
     combined_output_path = Path(__file__).parent / "fsaverage_data.json"
+    logger.info(f"Saving fsaverage affine and header data to {combined_output_path}")
     with open(combined_output_path, 'w') as f:
         json.dump(combined_data_serializable, f, indent=2)
 
diff --git a/CorpusCallosum/data/read_write.py b/CorpusCallosum/data/read_write.py
index 5a07f605..b22e804e 100644
--- a/CorpusCallosum/data/read_write.py
+++ b/CorpusCallosum/data/read_write.py
@@ -3,6 +3,10 @@
 import nibabel as nib
 import numpy as np
 
+import FastSurferCNN.utils.logging as logging
+
+logger = logging.get_logger(__name__)
+
 
 def run_in_background(function, debug=False, *args, **kwargs):
     """Run a function in the background using multiprocessing.
@@ -102,6 +106,7 @@ def save_nifti_background(io_processes, data, affine, header, filepath):
         header: NIfTI header object containing metadata
         filepath (str): Path where the image should be saved
     """
+    logger.info(f"Saving NIfTI image to {filepath}")
     io_processes.append(run_in_background(nib.save, False, nib.MGHImage(data, affine, header), filepath))
 
 
diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py
index 342c795b..5e79caaa 100644
--- a/CorpusCallosum/fastsurfer_cc.py
+++ b/CorpusCallosum/fastsurfer_cc.py
@@ -7,6 +7,7 @@
 import numpy as np
 import torch
 
+import FastSurferCNN.utils.logging as logging
 from CorpusCallosum.data.constants import (
     CC_LABEL,
     FSAVERAGE_CENTROIDS_PATH,
@@ -37,6 +38,8 @@
 from recon_surf import lta
 from recon_surf.align_points import find_rigid
 
+logger = logging.get_logger(__name__)
+
 
 def options_parse() -> argparse.Namespace:
     """Parse command line arguments for the pipeline."""
@@ -66,10 +69,8 @@
         "Required if --in_mri and --aseg are not both provided.",
         default=None,
     )
-    parser.add_argument("--debug_output_dir", type=str, required=False, default=None)
-    parser.add_argument("--verbose", action="store_true", help="Enable verbose output and debug plots")
-
-    # CC shape arguments
+    parser.add_argument("--debug_output_dir", type=str, required=False, default=None,
+                        help="Directory for debug output (default: subject_dir/qc_snapshots)")
     parser.add_argument(
         "--num_thickness_points", type=int, default=100, help="Number of points for thickness estimation."
     )
@@ -77,7 +78,7 @@
         "--subdivisions",
         type=float,
         nargs="+",
-        default=[1 / 6, 1 / 2, 2 / 3, 3 / 4],
+        default=[1/6, 1/2, 2/3, 3/4],
         help="List of subdivision fractions for the corpus callosum subsegmentation.",
     )
     parser.add_argument(
@@ -101,57 +102,57 @@
     parser.add_argument(
         "--slice_selection",
         type=str,
-        default="middle",
-        help="Which slices to process. Options: 'middle' (default), 'all', or a specific slice number.",
+        default="all",
+        help="Which slices to process. Options: 'middle', 'all' (default), or a specific slice number.",
     )
-
-    # Output path arguments
     parser.add_argument(
         "--upright_volume_path",
         type=str,
-        help="Path for upright volume output (default: subject_dir/stats/upright_volume.mgz)",
+        help=f"Path for upright volume output (default: subject_dir/{STANDARD_OUTPUT_PATHS['upright_volume']})",
         default=None,
     )
     parser.add_argument(
         "--segmentation_path",
         type=str,
-        help="Path for segmentation output (default: subject_dir/stats/cc_segmentation.mgz)",
+        help=f"Path for segmentation output (default: subject_dir/{STANDARD_OUTPUT_PATHS['segmentation']})",
         default=None,
     )
     parser.add_argument(
         "--postproc_results_path",
         type=str,
-        help="Path for postprocessing results (default: subject_dir/stats/cc_postproc_results.json)",
+        help=f"Path for postprocessing results (default: subject_dir/{STANDARD_OUTPUT_PATHS['postproc_results']})",
         default=None,
     )
     parser.add_argument(
         "--cc_markers_path",
         type=str,
-        help="Path for CC markers output (default: subject_dir/stats/cc_markers.json)",
+        help=f"Path for CC markers output (default: subject_dir/{STANDARD_OUTPUT_PATHS['cc_markers']})",
         default=None,
     )
     parser.add_argument(
         "--upright_lta_path",
         type=str,
-        help="Path for upright LTA transform (default: subject_dir/transforms/upright.lta)",
+        help=f"Path for upright LTA transform (default: subject_dir/{STANDARD_OUTPUT_PATHS['upright_lta']})",
         default=None,
     )
     parser.add_argument(
         "--orient_volume_lta_path",
         type=str,
-        help="Path for orientation volume LTA transform (default: subject_dir/transforms/orient_volume.lta)",
+        help="Path for orientation volume LTA transform "
+             f"(default: subject_dir/{STANDARD_OUTPUT_PATHS['orient_volume_lta']})",
         default=None,
     )
     parser.add_argument(
         "--orig_space_segmentation_path",
         type=str,
-        help="Path for segmentation in original space (default: subject_dir/mri/segmentation_orig_space.mgz)",
+        help="Path for segmentation in original space "
+             f"(default: subject_dir/{STANDARD_OUTPUT_PATHS['orig_space_segmentation']})",
         default=None,
     )
     parser.add_argument(
         "--debug_image_path",
         type=str,
-        help="Path for debug visualization image (default: subject_dir/qc_snapshots/cc_postprocessing.png)",
+        help=f"Path for debug visualization image (default: subject_dir/{STANDARD_OUTPUT_PATHS['debug_image']})",
         default=None,
     )
     parser.add_argument(
@@ -163,9 +164,52 @@ parser.add_argument(
         "--thickness_image_path",
         type=str,
-        help="Path for thickness image (default: subject_dir/qc_snapshots/corpus_callosum_thickness_3d.png)",
+        help=f"Path for thickness image (default: subject_dir/{STANDARD_OUTPUT_PATHS['thickness_image']})",
+        default=None,
+    )
+    parser.add_argument(
+        "--surf_file_path",
+        type=str,
+        help=f"Path for surf file (default: subject_dir/{STANDARD_OUTPUT_PATHS['surf_file']})",
+        default=None,
+    )
+    parser.add_argument(
+        "--overlay_file_path",
+        type=str,
+        help=f"Path for overlay file (default: subject_dir/{STANDARD_OUTPUT_PATHS['overlay_file']})",
+        default=None,
+    )
+    parser.add_argument(
+        "--cc_html_path",
+        type=str,
+        help=f"Path for CC HTML file (default: subject_dir/{STANDARD_OUTPUT_PATHS['cc_html']})",
+        default=None,
+    )
+    parser.add_argument(
+        "--vtk_file_path",
+        type=str,
+        help=f"Path for vtk file (default: subject_dir/{STANDARD_OUTPUT_PATHS['vtk_file']})",
+        default=None,
+    )
+    parser.add_argument(
+        "--softlabels_cc_path",
+        type=str,
+        help=f"Path for cc softlabels (default: subject_dir/{STANDARD_OUTPUT_PATHS['softlabels_cc']})",
         default=None,
     )
+    parser.add_argument(
+        "--softlabels_fn_path",
+        type=str,
+        help=f"Path for fornix softlabels (default: subject_dir/{STANDARD_OUTPUT_PATHS['softlabels_fn']})",
+        default=None,
+    )
+    parser.add_argument(
+        "--softlabels_background_path",
+        type=str,
+        help=f"Path for background softlabels (default: subject_dir/{STANDARD_OUTPUT_PATHS['softlabels_background']})",
+        default=None,
+    )
+    parser.add_argument("--verbose", action="store_true", help="Enable verbose output (shows output paths)", default=False)
 
     args = parser.parse_args()
 
@@ -191,7 +235,7 @@
 
     # Set default output paths if not provided
     for key, value in STANDARD_OUTPUT_PATHS.items():
-        if not getattr(args, f"{key}_path"):
+        if not getattr(args, f"{key}_path") and value is not None:
             setattr(args, f"{key}_path", str(subject_dir_path / value))
 
     # Set output_dir to subject_dir
@@ -262,7 +306,7 @@
     return orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, fsaverage_hires_affine, fsaverage_header
 
 
-def localize_ac_pc(midslices, aseg_nib, orig_fsaverage_vox2vox, model_localization, slices_to_analyze, verbose=False):
+def localize_ac_pc(midslices, aseg_nib, orig_fsaverage_vox2vox, model_localization, slices_to_analyze):
     """Localize anterior and posterior commissure points in the brain.
 
     Uses a trained model to detect AC and PC points in mid-sagittal slices,
@@ -275,15 +319,12 @@
         fsaverage_hires_affine (np.ndarray): High-resolution fsaverage affine matrix
         model_localization: Trained model for AC-PC detection
         slices_to_analyze (int): Number of slices to process
-        verbose (bool): Whether to print progress information
 
     Returns:
         tuple: Contains:
             - ac_coords (np.ndarray): Coordinates of the anterior commissure
             - pc_coords (np.ndarray): Coordinates of the posterior commissure
     """
-    if verbose:
-        print("Localization and segmentation inference")
 
     # get center of third ventricle from aseg and map to fsaverage space
     third_ventricle_mask = aseg_nib.get_fdata() == 4
@@ -373,10 +414,17 @@ def main(
     cc_markers_path: str | Path = None,
     upright_lta_path: str | Path = None,
     orient_volume_lta_path: str | Path = None,
+    surf_file_path: str | Path = None,
+    overlay_file_path: str | Path = None,
+    cc_html_path: str | Path = None,
+    vtk_file_path: str | Path = None,
     orig_space_segmentation_path: str | Path = None,
     debug_image_path: str | Path = None,
     save_template: str | Path | None = None,
     thickness_image_path: str | Path = None,
+    softlabels_cc_path: str | Path = None,
+    softlabels_fn_path: str | Path = None,
+    softlabels_background_path: str | Path = None,
     cpu: bool = False,
 ) -> None:
     """Main pipeline function for corpus callosum analysis.
@@ -412,6 +460,13 @@ def main(
         save_template: Directory path where to save contours.txt and thickness_values.txt files
         thickness_image_path: Path for thickness image (default: output_dir/qc_snapshots/corpus_callosum_thickness_3d.png)
+        surf_file_path: Path for surf file (default: output_dir/surf/callosum.surf)
+        overlay_file_path: Path for overlay file (default: output_dir/surf/callosum.thickness.w)
+        cc_html_path: Path for CC HTML file (default: output_dir/qc_snapshots/corpus_callosum.html)
+        vtk_file_path: Path for vtk file (default: output_dir/qc_snapshots/callosum_mesh.vtk)
+        softlabels_cc_path: Path for cc softlabels (default: output_dir/mri/callosum_seg_soft.mgz)
+        softlabels_fn_path: Path for fornix softlabels (default: output_dir/mri/fornix_seg_soft.mgz)
+        softlabels_background_path: Path for background softlabels (default: output_dir/mri/background_seg_soft.mgz)
         cpu: Force CPU usage even when CUDA is available
 
     The function saves multiple outputs to specified paths or default locations in output_dir:
@@ -426,6 +481,15 @@
     if subdivisions is None:
         subdivisions = [1 / 6, 1 / 2, 2 / 3, 3 / 4]
 
+    # Set up logging if verbose mode is enabled
+    if verbose:
+        logging.setup_logging(None)  # Log to stdout only
+
+    logger.info("Starting corpus callosum analysis pipeline")
+    logger.info(f"Input MRI: {in_mri_path}")
+    logger.info(f"Input segmentation: {aseg_path}")
+    logger.info(f"Output directory: {output_dir}")
+
     # Convert all paths to Path objects
     in_mri_path = Path(in_mri_path)
     aseg_path = Path(aseg_path)
@@ -450,33 +514,39 @@
         slices_to_analyze += 1
 
     if verbose:
-        print(
+        logger.info(
             f"Segmenting {slices_to_analyze} slices (5 mm width at {orig.header.get_zooms()[0]} mm resolution, "
             "center around the mid-sagittal plane)"
         )
 
     if not is_conform(orig):
-        print("Error: MRI is not conformed, please run conform.py or mri_convert to conform the image.")
+        logger.error("Error: MRI is not conformed, please run conform.py or mri_convert to conform the image.")
         exit(1)
 
     # load models
     device = torch.device("cuda" if torch.cuda.is_available() and not cpu else "cpu")
+    logger.info(f"Using device: {device}")
+
+    logger.info("Loading localization model")
     model_localization = localization_inference.load_model(
         str(Path(WEIGHTS_PATH) / "localization_weights_acpc.pth"), device=device
     )
+    logger.info("Loading segmentation model")
     model_segmentation = segmentation_inference.load_model(
         str(Path(WEIGHTS_PATH) / "segmentation_weights_cc_fn.pth"), device=device
     )
 
     aseg_nib = nib.load(aseg_path)
+    logger.info("Performing centroid registration to fsaverage space")
     orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, fsaverage_hires_affine, fsaverage_header = centroid_registration(
-        aseg_nib, verbose
+        aseg_nib, verbose=False
     )
 
-    if verbose:
-        print("Interpolating midplane")
+    logger.info("Interpolating midplane slices")
     # this is a fast interpolation to not block the main thread
     midslices = interpolate_midplane(orig, orig_fsaverage_vox2vox, slices_to_analyze)
@@ -495,13 +565,37 @@
     )
 
     #### do localization and segmentation inference
+    logger.info("Starting AC/PC localization")
     ac_coords, pc_coords = localize_ac_pc(
-        midslices, aseg_nib, orig_fsaverage_vox2vox, model_localization, slices_to_analyze, verbose
+        midslices, aseg_nib, orig_fsaverage_vox2vox, model_localization, slices_to_analyze
     )
+    logger.info("Starting corpus callosum segmentation")
     segmentation, outputs_soft = segment_cc(
         midslices, ac_coords, pc_coords, aseg_nib, model_segmentation,
slices_to_analyze ) + + # calculate affine for segmentation volume + orig_to_seg = np.eye(4) + orig_to_seg[0, 3] = -FSAVERAGE_MIDDLE + slices_to_analyze // 2 + seg_affine = fsaverage_hires_affine + seg_affine = seg_affine @ np.linalg.inv(orig_to_seg) + + # save softlabels + if softlabels_background_path is not None: + if verbose: + logger.info(f"Saving background softlabels to {softlabels_background_path}") + save_nifti_background(IO_processes, outputs_soft[..., 0], seg_affine, orig.header, softlabels_background_path) + if softlabels_cc_path is not None: + if verbose: + logger.info(f"Saving cc softlabels to {softlabels_cc_path}") + save_nifti_background(IO_processes, outputs_soft[..., 1], seg_affine, orig.header, softlabels_cc_path) + if softlabels_fn_path is not None: + if verbose: + logger.info(f"Saving fornix softlabels to {softlabels_fn_path}") + save_nifti_background(IO_processes, outputs_soft[..., 2], seg_affine, orig.header, softlabels_fn_path) + + # map soft labels to original space (in parallel because this takes a while) IO_processes.append( run_in_background( @@ -520,6 +614,7 @@ def main( temp_seg_affine = fsaverage_hires_affine @ np.linalg.inv(np.eye(4)) # Process slices based on selection mode + logger.info(f"Processing slices with selection mode: {slice_selection}") slice_results, slice_io_processes = process_slices( segmentation=segmentation, slice_selection=slice_selection, @@ -533,6 +628,11 @@ def main( contour_smoothing=contour_smoothing, output_dir=output_dir, debug_image_path=debug_image_path, + surf_file_path=surf_file_path, + overlay_file_path=overlay_file_path, + cc_html_path=cc_html_path, + vtk_file_path=vtk_file_path, + thickness_image_path=thickness_image_path, vox_size=orig.header.get_zooms()[0], verbose=verbose, save_template=save_template, @@ -575,7 +675,7 @@ def main( json.dump(per_slice_output_dict, f, indent=4) if verbose: - print(f"Multiple slice post-processing results saved to {postproc_results_path}") + logger.info(f"Multiple slice post-processing results saved to {postproc_results_path}") ########## Save outputs ########## @@ -640,11 +740,7 @@ def main( get_mapping_to_standard_space(orig, ac_coords_3d, pc_coords_3d, orig_fsaverage_vox2vox, output_dir) ) - # save segmentation with fitting affine - orig_to_seg = np.eye(4) - orig_to_seg[0, 3] = -FSAVERAGE_MIDDLE + slices_to_analyze // 2 - seg_affine = fsaverage_hires_affine - seg_affine = seg_affine @ np.linalg.inv(orig_to_seg) + save_nifti_background(IO_processes, segmentation, seg_affine, orig.header, segmentation_path) # write output dict as csv @@ -660,16 +756,19 @@ def main( # Convert numpy arrays to lists for JSON serialization output_dict = convert_numpy_to_json_serializable(output_dict) + logger.info(f"Saving CC markers to {cc_markers_path}") with open(cc_markers_path, "w") as f: json.dump(output_dict, f, indent=4) # save lta to fsaverage space + logger.info(f"Saving LTA to fsaverage space: {upright_lta_path}") lta.writeLTA(upright_lta_path, orig_fsaverage_ras2ras, aseg_path, aseg_nib.header, "fsaverage", fsaverage_header) # save lta to standardized space (fsaverage + nodding + ac to center) orig_to_standardized_ras2ras = ( orig.affine @ np.linalg.inv(standardized_to_orig_vox2vox) @ np.linalg.inv(orig.affine) ) + logger.info(f"Saving LTA to standardized space: {orient_volume_lta_path}") lta.writeLTA( orient_volume_lta_path, orig_to_standardized_ras2ras, in_mri_path, orig.header, in_mri_path, orig.header ) @@ -677,6 +776,8 @@ def main( for process in IO_processes: if process is not None: 
process.join() + + logger.info("CorpusCallosum analysis pipeline completed successfully") if __name__ == "__main__": diff --git a/CorpusCallosum/registration/mapping_helpers.py b/CorpusCallosum/registration/mapping_helpers.py index 3e4a0d1d..f4c06719 100644 --- a/CorpusCallosum/registration/mapping_helpers.py +++ b/CorpusCallosum/registration/mapping_helpers.py @@ -2,6 +2,10 @@ import numpy as np from scipy.ndimage import affine_transform +import FastSurferCNN.utils.logging as logging + +logger = logging.get_logger(__name__) + def make_midplane_affine(orig_affine, slices_to_analyze=1, offset=4): """ @@ -190,6 +194,7 @@ def apply_transform_and_map_volume( order=order, ) if output_path is not None: + logger.info(f"Saving transformed volume to {output_path}") nib.save(nib.MGHImage(transformed, affine, header), output_path) return transformed @@ -280,6 +285,7 @@ def map_softlabels_to_orig( ) if orig_space_segmentation_path is not None: + logger.info(f"Saving segmentation in original space to {orig_space_segmentation_path}") nib.save( nib.MGHImage(segmentation_orig_space, orig.affine, orig.header), orig_space_segmentation_path, diff --git a/CorpusCallosum/shape/cc_mesh.py b/CorpusCallosum/shape/cc_mesh.py index 645a1c88..f67a4577 100644 --- a/CorpusCallosum/shape/cc_mesh.py +++ b/CorpusCallosum/shape/cc_mesh.py @@ -13,6 +13,10 @@ from shape.cc_thickness import HiddenPrints, make_mesh_from_contour from whippersnappy.core import snap1 +import FastSurferCNN.utils.logging as logging + +logger = logging.get_logger(__name__) + class CC_Mesh(lapy.TriaMesh): """A class for representing and manipulating corpus callosum (CC) meshes. @@ -1114,6 +1118,7 @@ def save_contours(self, output_path: str): Args: output_path (str): Path where to save the CSV file. """ + logger.info(f"Saving contours to CSV file: {output_path}") with open(output_path, "w") as f: # Write header f.write("slice_idx,x,y\n") @@ -1185,6 +1190,7 @@ def save_thickness_values(self, output_path: str): Args: output_path (str): Path where to save the CSV file. """ + logger.info(f"Saving thickness data to CSV file: {output_path}") with open(output_path, "w") as f: # Write header f.write("slice_idx,thickness\n") @@ -1315,6 +1321,7 @@ def save_thickness_measurement_points(self, filename): Args: filename (str): Path where to save the CSV file. 
""" + logger.info(f"Saving thickness measurement points to CSV file: {filename}") with open(filename, "w") as f: f.write("slice_idx,vertex_idx\n") for slice_idx, vertex_indices in enumerate(self.original_thickness_vertices): diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/cc_postprocessing.py index 3017b3dc..761497dc 100644 --- a/CorpusCallosum/shape/cc_postprocessing.py +++ b/CorpusCallosum/shape/cc_postprocessing.py @@ -13,10 +13,15 @@ ) from shape.cc_thickness import cc_thickness, convert_to_ras +import FastSurferCNN.utils.logging as logging from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE from CorpusCallosum.data.read_write import run_in_background +from CorpusCallosum.utils.utils import HiddenPrints from CorpusCallosum.visualization.visualization import plot_contours +logger = logging.get_logger(__name__) + + # assert LIA orientation LIA_ORIENTATION = np.zeros((3,3)) LIA_ORIENTATION[0,0] = -1 @@ -40,16 +45,29 @@ def create_visualization(subdivision_method, result, midslices_data, output_imag Returns: Process object for background execution """ - title = f'CC Subsegmentation: {subdivision_method}{title_suffix}' - + title = f'CC Subsegmentation by {subdivision_method} {title_suffix}' + + args_dict = { + 'debug': False, + 'transformed': midslices_data, + 'split_contours': None, + 'split_contours_hofer_frahm': None, + 'midline_equidistant': result['midline_equidistant'], + 'levelpaths': result['levelpaths'], + 'output_path': output_image_path, + 'ac_coords': ac_coords, + 'pc_coords': pc_coords, + 'vox_size': vox_size, + 'title': title, + } + if subdivision_method == "shape": - return run_in_background(plot_contours, False, midslices_data, - result['split_contours'], None, result['midline_equidistant'], result['levelpaths'], - output_image_path, ac_coords, pc_coords, vox_size, title) + args_dict['split_contours'] = result['split_contours'] else: - return run_in_background(plot_contours, False, midslices_data, - None, result['split_contours_hofer_frahm'], result['midline_equidistant'], - result['levelpaths'], output_image_path, ac_coords, pc_coords, vox_size, title) + args_dict['split_contours_hofer_frahm'] = result['split_contours_hofer_frahm'] + + return run_in_background(plot_contours, **args_dict) + def create_slice_affine(temp_seg_affine, slice_idx, fsaverage_middle): @@ -93,7 +111,6 @@ def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thi subdivision_method (str): Method for contour subdivision ('shape', 'vertical', 'angular', or 'eigenvector') contour_smoothing (float): Gaussian sigma for contour smoothing - verbose (bool): Whether to print progress information Returns: dict or None: Dictionary containing measurements if successful, including: @@ -159,7 +176,7 @@ def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thi split_contours_hofer_frahm = split_contours.copy() elif subdivision_method == "angular": if not np.allclose(np.diff(subdivisions), np.diff(subdivisions)[0]): - print('Error: Angular subdivision method (Hampel) only supports equidistant subdivision, ' + logger.error('Error: Angular subdivision method (Hampel) only supports equidistant subdivision, ' f'but got: {subdivisions}') return None areas, split_contours = hampel_subdivide_contour(contour_acpc, num_rays=len(subdivisions), plot=False) @@ -202,8 +219,9 @@ def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thi def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac_coords, 
                   pc_coords, num_thickness_points, subdivisions, subdivision_method, contour_smoothing,
-                   output_dir, debug_image_path=None, thickness_image_path=None, vox_size=None, verbose=False,
-                   save_template=None):
+                   output_dir, debug_image_path=None, thickness_image_path=None, vox_size=None,
+                   save_template=None, surf_file_path=None, overlay_file_path=None, cc_html_path=None,
+                   vtk_file_path=None, verbose=False):
     """Process corpus callosum slices based on selection mode.
 
     Handles the processing of either a single middle slice, all slices, or a specific slice,
@@ -242,26 +260,32 @@
         slice_idx = segmentation.shape[0] // 2
         slice_affine = create_slice_affine(temp_seg_affine, slice_idx, FSAVERAGE_MIDDLE)
 
-        result, contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx = process_slice(segmentation,
-                                                                                slice_idx,
-                                                                                ac_coords,
-                                                                                pc_coords,
-                                                                                slice_affine,
-                                                                                num_thickness_points, subdivisions, subdivision_method, contour_smoothing)
+        (result, contour_with_thickness,
+         anterior_endpoint_idx, posterior_endpoint_idx) = process_slice(segmentation,
+                                                                        slice_idx,
+                                                                        ac_coords,
+                                                                        pc_coords,
+                                                                        slice_affine,
+                                                                        num_thickness_points,
+                                                                        subdivisions,
+                                                                        subdivision_method,
+                                                                        contour_smoothing)
 
         cc_mesh.add_contour(0, contour_with_thickness[0], contour_with_thickness[1],
                             start_end_idx=(anterior_endpoint_idx, posterior_endpoint_idx))
 
         if result is not None:
             slice_results.append(result)
             # Create visualization
-            IO_processes.append(create_visualization(subdivision_method, result, midslices,
-                                                     debug_image_path, ac_coords, pc_coords,
-                                                     vox_size))
+            if debug_image_path is not None:
+                if verbose:
+                    logger.info(f"Saving segmentation qc image to {debug_image_path}")
+                IO_processes.append(create_visualization(subdivision_method, result, midslices,
+                                                         debug_image_path, ac_coords, pc_coords,
+                                                         vox_size))
     else:
-
-        cc_mesh = CC_Mesh(num_slices=segmentation.shape[0])
+        num_slices = segmentation.shape[0]
+        cc_mesh = CC_Mesh(num_slices=num_slices)
         cc_mesh.set_acpc_coords(ac_coords, pc_coords)
         cc_mesh.set_resolution(1)  # contour is always scaled to 1 mm
@@ -276,7 +300,7 @@ for slice_idx in range(start_slice, end_slice):
             if verbose:
-                print(f"Calculating CC measurements for slice {slice_idx+1} of {end_slice-start_slice}")
+                logger.info(f"Calculating CC measurements for slice {slice_idx+1} of {end_slice-start_slice}")
 
             # Update affine for this slice
             slice_affine = create_slice_affine(temp_seg_affine, slice_idx, FSAVERAGE_MIDDLE)
@@ -297,40 +321,69 @@
             if result is not None:
                 slice_results.append(result)
+
+                debug_path_base, debug_path_ext = str(debug_image_path).rsplit('.', 1)
+                debug_path_with_postfix = f"{debug_path_base}_slice_{slice_idx}"
+                debug_output_path_slice = Path(f"{debug_path_with_postfix}.{debug_path_ext}").with_suffix('.png')
 
-                # For single slice mode, save to main directory
-                if slice_selection != "all":
-                    output_subdir = output_dir
-                else:
-                    # For all slices mode, create per-slice directory
-                    output_subdir = output_dir / f'slice_{slice_idx}'
-                    output_subdir.mkdir(exist_ok=True)
-
+                if verbose:
+                    logger.info(f"Saving segmentation qc image to {debug_output_path_slice}")
+
+                current_slice_in_volume = midslices.shape[0] // 2 - num_slices // 2 + slice_idx
                 # Create visualization for this slice
-                IO_processes.append(create_visualization(subdivision_method, result, midslices[slice_idx:slice_idx+1],
-                                                         output_subdir, ac_coords, pc_coords,
+                IO_processes.append(create_visualization(subdivision_method, result,
+ midslices[current_slice_in_volume:current_slice_in_volume+1], + debug_output_path_slice, ac_coords, pc_coords, vox_size, f' (Slice {slice_idx})')) if save_template is not None: # Convert to Path object and ensure directory exists template_dir = Path(save_template) template_dir.mkdir(parents=True, exist_ok=True) + if verbose: + logger.info("Saving template files (contours.txt, thickness_values.txt, " + f"thickness_measurement_points.txt) to {template_dir}") cc_mesh.save_contours(str(template_dir / 'contours.txt')) cc_mesh.save_thickness_values(str(template_dir / 'thickness_values.txt')) cc_mesh.save_thickness_measurement_points(str(template_dir / 'thickness_measurement_points.txt')) - if len(cc_mesh.contours) > 1: + if len(cc_mesh.contours) > 1 and thickness_image_path is not None: cc_mesh.fill_thickness_values() cc_mesh.create_mesh() cc_mesh.smooth_(1) - cc_mesh.plot_mesh() + cc_mesh.plot_mesh(output_path=cc_html_path) + + if vtk_file_path is not None: + if verbose: + logger.info(f"Saving vtk file to {vtk_file_path}") + cc_mesh.write_vtk(str(vtk_file_path)) #cc_mesh.write_vtk(str(output_dir / 'cc_mesh.vtk')) - cc_mesh.snap_cc_picture(str(output_dir / thickness_image_path)) + + + cc_mesh.to_fs_coordinates() + + if overlay_file_path is not None: + if verbose: + logger.info(f"Saving overlay file to {overlay_file_path}") + cc_mesh.write_overlay(str(overlay_file_path)) + + if surf_file_path is not None: + if verbose: + logger.info(f"Saving surf file to {surf_file_path}") + cc_mesh.write_fssurf(str(surf_file_path)) + + + + if thickness_image_path is not None: + if verbose: + logger.info(f"Saving thickness image to {thickness_image_path}") + with HiddenPrints(): + cc_mesh.snap_cc_picture(str(output_dir / thickness_image_path)) if not slice_results: - print("Error: No valid slices were found for postprocessing") + logger.error("Error: No valid slices were found for postprocessing") exit(1) return slice_results, IO_processes \ No newline at end of file diff --git a/CorpusCallosum/shape/cc_thickness.py b/CorpusCallosum/shape/cc_thickness.py index 3d5be36d..afa8fd31 100644 --- a/CorpusCallosum/shape/cc_thickness.py +++ b/CorpusCallosum/shape/cc_thickness.py @@ -1,21 +1,10 @@ -import os -import sys - import meshpy.triangle as triangle import numpy as np import scipy.interpolate from lapy import Solver, TriaMesh from lapy.diffgeo import compute_rotated_f - -class HiddenPrints: - def __enter__(self): - self._original_stdout = sys.stdout - sys.stdout = open(os.devnull, "w") - - def __exit__(self, exc_type, exc_val, exc_tb): - sys.stdout.close() - sys.stdout = self._original_stdout +from CorpusCallosum.utils.utils import HiddenPrints def compute_curvature(path): diff --git a/CorpusCallosum/visualization/visualization.py b/CorpusCallosum/visualization/visualization.py index 4ee3fd5e..29b431da 100644 --- a/CorpusCallosum/visualization/visualization.py +++ b/CorpusCallosum/visualization/visualization.py @@ -127,7 +127,7 @@ def plot_contours( if split_contours_hofer_frahm is not None: NO_PLOTS += 1 - fig, ax = plt.subplots(1, NO_PLOTS, sharex=True, sharey=True, figsize=(15, 10)) + _, ax = plt.subplots(1, NO_PLOTS, sharex=True, sharey=True, figsize=(15, 10)) PLT_NUM = 0 diff --git a/recon_surf/align_points.py b/recon_surf/align_points.py index e549466d..e2ed0e5b 100755 --- a/recon_surf/align_points.py +++ b/recon_surf/align_points.py @@ -127,8 +127,7 @@ def find_rotation(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: return R - -def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: 
+def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray, verbose: bool = False) -> np.ndarray:
     """
     Find rigid transformation matrix between two point sets.
 
@@ -138,6 +137,8 @@
     p_mov : npt.NDArray
         Source points.
     p_dst : npt.NDArray
         Destination points.
+    verbose : bool, optional
+        Whether to print debug information, by default False.
 
     Returns
     -------
@@ -163,10 +164,11 @@
     T[:m, :m] = R
     T[:m, m] = t
-    # compute disteances
-    dd = p_mov - p_dst
-    print(f"Initial avg SSD: {np.sum(dd * dd) / p_mov.shape[0]}")
-    dd = (np.transpose(R @ np.transpose(p_mov)) + t) - p_dst
-    print(f"Final avg SSD: {np.sum(dd * dd) / p_mov.shape[0]}")
+    # compute distances
+    if verbose:
+        dd = p_mov - p_dst
+        print(f"Initial avg SSD: {np.sum(dd * dd) / p_mov.shape[0]}")
+        dd = (np.transpose(R @ np.transpose(p_mov)) + t) - p_dst
+        print(f"Final avg SSD: {np.sum(dd * dd) / p_mov.shape[0]}")
     # return T, R, t
     return T

From f06f2824a832a3c78437177ce4d53dd25bf51893 Mon Sep 17 00:00:00 2001
From: ClePol
Date: Tue, 23 Sep 2025 15:18:24 +0200
Subject: [PATCH 10/68] cc painting script for reconsurf integration

---
 recon_surf/recon-surf.sh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/recon_surf/recon-surf.sh b/recon_surf/recon-surf.sh
index 6f3cdb5f..d415fd7b 100755
--- a/recon_surf/recon-surf.sh
+++ b/recon_surf/recon-surf.sh
@@ -630,8 +630,8 @@ fi
 #cmd="mri_cc -aseg $aseg_nocc -o aseg.auto.mgz -lta $mdir/transforms/cc_up.lta $subject"
 #RunIt "$cmd" "$LF"
 # add CC into aparc.DKTatlas+aseg.deep (not sure if this is really needed)
-#cmd="$python ${binpath}paint_cc_into_pred.py -in_cc $mdir/aseg.auto.mgz -in_pred $asegdkt_segfile -out $mdir/aparc.DKTatlas+aseg.deep.withCC.mgz"
-#RunIt "$cmd" "$LF"
+cmd="$python ${binpath}/../CorpusCallosum/paint_cc_into_pred.py -in_cc $mdir/aseg.auto.mgz -in_pred $asegdkt_segfile -out $mdir/aparc.DKTatlas+aseg.deep.withCC.mgz"
+RunIt "$cmd" "$LF"
 
 # ============================= FILLED =====================================================

From fde4739c511558230ef093dc5b5e7e9c971fef61 Mon Sep 17 00:00:00 2001
From: ClePol
Date: Wed, 24 Sep 2025 16:08:56 +0200
Subject: [PATCH 11/68] added midslice based 3D subsegmentation in orig space outputs cleaned up logging and docs cleaned up folder creation and savepath logic

---
 CorpusCallosum/fastsurfer_cc.py               |  87 ++++++---
 .../registration/mapping_helpers.py           |  21 ++
 .../segmentation_postprocessing.py            |  82 +++++++-
 CorpusCallosum/shape/cc_mesh.py               |  18 +-
 CorpusCallosum/shape/cc_postprocessing.py     | 180 +++++++++++++++++-
 CorpusCallosum/utils/utils.py                 |  12 ++
 6 files changed, 359 insertions(+), 41 deletions(-)
 create mode 100644 CorpusCallosum/utils/utils.py

diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py
index 5e79caaa..34dab18c 100644
--- a/CorpusCallosum/fastsurfer_cc.py
+++ b/CorpusCallosum/fastsurfer_cc.py
@@ -33,7 +38,12 @@
     map_softlabels_to_orig,
 )
 from CorpusCallosum.segmentation import segmentation_inference, segmentation_postprocessing
-from CorpusCallosum.shape.cc_postprocessing import create_visualization, process_slices
+from CorpusCallosum.shape.cc_postprocessing import (
+    check_area_changes,
+    create_visualization,
+    make_subdivision_mask,
+    process_slices,
+)
 from FastSurferCNN.data_loader.conform import is_conform
 from recon_surf import lta
 from recon_surf.align_points import find_rigid
@@ -242,16 +247,10 @@
     args.output_dir = str(subject_dir_path)

    # 
Create parent directories for all output paths - for path in [ - args.upright_volume_path, - args.segmentation_path, - args.postproc_results_path, - args.cc_markers_path, - args.upright_lta_path, - args.orient_volume_lta_path, - ]: + for path_name in STANDARD_OUTPUT_PATHS.keys(): + path = getattr(args, f"{path_name}_path") if path is not None: - Path(path).parent.mkdir(parents=True, exist_ok=True) + Path(path).parent.mkdir(parents=False, exist_ok=True) return args @@ -397,6 +396,8 @@ def segment_cc(midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, sl return segmentation, outputs_soft + + def main( in_mri_path: str | Path, aseg_path: str | Path, @@ -408,6 +409,9 @@ def main( subdivisions: list[float] | None = None, subdivision_method: str = "shape", contour_smoothing: float = 1.0, + save_template: str | Path | None = None, + cpu: bool = False, + # output paths upright_volume_path: str | Path = None, segmentation_path: str | Path = None, postproc_results_path: str | Path = None, @@ -420,12 +424,10 @@ def main( vtk_file_path: str | Path = None, orig_space_segmentation_path: str | Path = None, debug_image_path: str | Path = None, - save_template: str | Path | None = None, thickness_image_path: str | Path = None, softlabels_cc_path: str | Path = None, softlabels_fn_path: str | Path = None, softlabels_background_path: str | Path = None, - cpu: bool = False, ) -> None: """Main pipeline function for corpus callosum analysis. @@ -448,6 +450,7 @@ def main( subdivisions: List of subdivision fractions for CC subsegmentation subdivision_method: Method for contour subdivision contour_smoothing: Gaussian sigma for smoothing during contour detection + cpu: Force CPU usage even when CUDA is available upright_volume_path: Path for upright volume output (default: output_dir/upright_volume.mgz) segmentation_path: Path for segmentation output (default: output_dir/segmentation.mgz) postproc_results_path: Path for postprocessing results (default: output_dir/cc_postproc_results.json) @@ -467,7 +470,7 @@ def main( softlabels_cc_path: Path for cc softlabels (default: output_dir/mri/callosum_seg_soft.mgz) softlabels_fn_path: Path for fornix softlabels (default: output_dir/mri/fornix_seg_soft.mgz) softlabels_background_path: Path for background softlabels (default: output_dir/mri/background_seg_soft.mgz) - cpu: Force CPU usage even when CUDA is available + The function saves multiple outputs to specified paths or default locations in output_dir: - cc_markers.json: Contains detected landmarks and measurements @@ -596,20 +599,6 @@ def main( save_nifti_background(IO_processes, outputs_soft[..., 2], seg_affine, orig.header, softlabels_fn_path) - # map soft labels to original space (in parallel because this takes a while) - IO_processes.append( - run_in_background( - map_softlabels_to_orig, - False, - outputs_soft, - orig_fsaverage_vox2vox, - orig, - slices_to_analyze, - orig_space_segmentation_path, - fsaverage_middle=FSAVERAGE_MIDDLE, - ) - ) - # Create a temporary segmentation image with proper affine for enhanced postprocessing temp_seg_affine = fsaverage_hires_affine @ np.linalg.inv(np.eye(4)) @@ -626,7 +615,6 @@ def main( subdivisions=subdivisions, subdivision_method=subdivision_method, contour_smoothing=contour_smoothing, - output_dir=output_dir, debug_image_path=debug_image_path, surf_file_path=surf_file_path, overlay_file_path=overlay_file_path, @@ -637,11 +625,37 @@ def main( verbose=verbose, save_template=save_template, ) + + + outer_contours = [slice_result['split_contours'][0] for slice_result in 
slice_results] + + if not check_area_changes(outer_contours, verbose=True): + logger.warning("Large area changes detected between consecutive slices, " + "this is likely due to a segmentation error.") + IO_processes.extend(slice_io_processes) # Get middle slice result for backward compatibility middle_slice_result = slice_results[len(slice_results) // 2] + subdivision_mask = make_subdivision_mask(segmentation.shape[1:], middle_slice_result['split_contours']) + + + # map soft labels to original space (in parallel because this takes a while) + IO_processes.append( + run_in_background( + map_softlabels_to_orig, + debug=True, + outputs_soft=outputs_soft, + orig_fsaverage_vox2vox=orig_fsaverage_vox2vox, + orig=orig, + slices_to_analyze=slices_to_analyze, + orig_space_segmentation_path=orig_space_segmentation_path, + fsaverage_middle=FSAVERAGE_MIDDLE, + subdivision_mask=subdivision_mask, + ) + ) + # Create enhanced output dictionary with all slice results per_slice_output_dict = { "slices": [ @@ -679,9 +693,20 @@ def main( ########## Save outputs ########## - cc_volume = segmentation_postprocessing.get_cc_volume( - desired_width_mm=5, cc_mask=segmentation == CC_LABEL, voxel_size=orig.header.get_zooms() - ) + if len(outer_contours) > 1: + cc_volume_old = segmentation_postprocessing.get_cc_volume( + desired_width_mm=5, + cc_mask=segmentation == CC_LABEL, + voxel_size=orig.header.get_zooms() + ) + cc_volume = segmentation_postprocessing.get_cc_volume_simpsons( + desired_width_mm=5, + cc_contours=outer_contours, + voxel_size=orig.header.get_zooms() + ) + else: + cc_volume = None + # Create backward compatible output_dict for existing pipeline using middle slice output_dict = { diff --git a/CorpusCallosum/registration/mapping_helpers.py b/CorpusCallosum/registration/mapping_helpers.py index f4c06719..a3861ccd 100644 --- a/CorpusCallosum/registration/mapping_helpers.py +++ b/CorpusCallosum/registration/mapping_helpers.py @@ -232,6 +232,7 @@ def map_softlabels_to_orig( slices_to_analyze, orig_space_segmentation_path=None, fsaverage_middle=128, + subdivision_mask=None, ): """ Maps soft labels back to original image space and applies post-processing. 
@@ -247,6 +248,7 @@ def map_softlabels_to_orig( Returns: segmentation_orig_space: Final segmentation in original image space """ + # map softlabels to original image softlabels_transformed = [] for i in range(outputs_soft.shape[-1]): @@ -277,6 +279,23 @@ def map_softlabels_to_orig( ) segmentation_orig_space = np.argmax(softlabels_orig_space, axis=-1) + + if subdivision_mask is not None: + # repeat subdivision mask for shape 0 of orig + subdivision_mask = np.repeat(subdivision_mask[np.newaxis, :, :], orig.shape[0], axis=0) + # map subdivision mask to orig space + subdivision_mask_orig_space = affine_transform( + subdivision_mask, + orig_fsaverage_vox2vox, + output_shape=orig.shape, + order=0, + ) + + segmentation_orig_space[segmentation_orig_space == 1] = \ + segmentation_orig_space[segmentation_orig_space == 1] * \ + subdivision_mask_orig_space[segmentation_orig_space == 1] + + segmentation_orig_space = np.where( segmentation_orig_space == 1, 192, segmentation_orig_space ) @@ -284,6 +303,8 @@ def map_softlabels_to_orig( segmentation_orig_space == 2, 250, segmentation_orig_space ) + + if orig_space_segmentation_path is not None: logger.info(f"Saving segmentation in original space to {orig_space_segmentation_path}") nib.save( diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py index dfba0e0e..84988c3e 100644 --- a/CorpusCallosum/segmentation/segmentation_postprocessing.py +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -1,9 +1,14 @@ import numpy as np -from scipy import ndimage +from scipy import integrate, ndimage from skimage.measure import label +import FastSurferCNN.utils.logging as logging from CorpusCallosum.data.constants import CC_LABEL, FORNIX_LABEL +logger = logging.get_logger(__name__) + + + def get_cc_volume(desired_width_mm: int, cc_mask: np.ndarray, voxel_size: tuple[float, float, float]) -> float: """Calculate the volume of the corpus callosum in cubic millimeters. @@ -57,7 +62,80 @@ def get_cc_volume(desired_width_mm: int, cc_mask: np.ndarray, voxel_size: tuple[ else: raise ValueError(f"Width of CC segmentation is smaller than desired width: {width_mm} < {desired_width_mm}") - +def get_cc_volume_simpsons(desired_width_mm: int, cc_contours: list[np.ndarray], + voxel_size: tuple[float, float, float]) -> float: + """Calculate the volume of the corpus callosum in cubic millimeters using Simpson's rule. + + This function calculates the volume of the corpus callosum (CC) in cubic millimeters using Simpson's rule. + If the CC width is larger than desired_width_mm, the voxels on the edges are calculated as + partial volumes to achieve the desired width. 
+
+    Args:
+        desired_width_mm (int): Desired width of the CC in millimeters
+        cc_contours (list[np.ndarray]): List of CC contours for each slice in the left-right direction
+        voxel_size (tuple[float, float, float]): Voxel size in millimeters (x, y, z)
+
+    Returns:
+        float: Volume of the CC in cubic millimeters
+
+    Raises:
+        ValueError: If CC width is smaller than desired width or insufficient contours for Simpson's rule
+    """
+    if len(cc_contours) < 3:
+        raise ValueError("Need at least 3 contours for Simpson's rule integration")
+
+    # Calculate cross-sectional areas for each contour
+    areas = []
+    for contour in cc_contours:
+        # Calculate area using the shoelace formula for polygon area
+        if contour.shape[1] < 3:
+            areas.append(0.0)
+        else:
+            x = contour[0]
+            y = contour[1]
+            # Shoelace formula: A = 0.5 * |sum(x_i * y_{i+1} - x_{i+1} * y_i)|
+            area = 0.5 * np.abs(np.sum(x[:-1] * y[1:] - x[1:] * y[:-1]))
+            # Convert from voxel^2 to mm^2
+            area_mm2 = area * voxel_size[1] * voxel_size[2]  # y * z voxel dimensions
+            areas.append(area_mm2)
+
+    areas = np.array(areas)
+
+    # Calculate spacing between slices (left-right direction)
+    lr_spacing = voxel_size[0]  # x-direction voxel size
+
+    # Get current width in mm
+    current_width_mm = len(cc_contours) * lr_spacing
+
+    if current_width_mm == desired_width_mm:
+        # Use Simpson's rule directly
+        return integrate.simpson(areas, dx=lr_spacing)
+    elif current_width_mm > desired_width_mm:
+        # Handle partial volumes at edges
+        desired_width_vox = desired_width_mm / lr_spacing
+        fraction_of_voxel_at_edge = (desired_width_vox % 1) / 2
+
+        if fraction_of_voxel_at_edge > 0:
+            # Apply partial volume correction to edge areas
+            areas_corrected = areas.copy()
+            areas_corrected[0] *= fraction_of_voxel_at_edge
+            areas_corrected[-1] *= fraction_of_voxel_at_edge
+
+            # Use Simpson's rule with corrected areas
+            return integrate.simpson(areas_corrected, dx=lr_spacing)
+        else:
+            # No partial volumes needed, truncate to desired width
+            desired_slices = int(desired_width_vox)
+            if desired_slices % 2 == 0:
+                desired_slices += 1  # Ensure odd number for Simpson's rule
+
+            start_idx = (len(areas) - desired_slices) // 2
+            end_idx = start_idx + desired_slices
+            truncated_areas = areas[start_idx:end_idx]
+
+            return integrate.simpson(truncated_areas, dx=lr_spacing)
+    else:
+        raise ValueError(f"Width of CC segmentation is smaller than desired width: {current_width_mm} < {desired_width_mm}")
 
 def get_largest_cc(seg_arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
     """Get largest connected component from a binary segmentation array.
diff --git a/CorpusCallosum/shape/cc_mesh.py b/CorpusCallosum/shape/cc_mesh.py
index f67a4577..86d1240a 100644
--- a/CorpusCallosum/shape/cc_mesh.py
+++ b/CorpusCallosum/shape/cc_mesh.py
@@ -389,6 +389,7 @@
         )
 
         if output_path is not None:
+            self.__make_parent_folder(output_path)
             fig.write_html(output_path)  # Save as interactive HTML
         else:
             # For non-interactive display, save to a temporary HTML and open in browser
@@ -718,6 +719,7 @@ def plot_contour(self, slice_idx: int, output_path: str):
         Raises:
             ValueError: If the contour for the specified slice is not set.
""" + self.__make_parent_folder(output_path) if self.contours[slice_idx] is None: raise ValueError(f"Contour for slice {slice_idx} is not set") @@ -976,8 +978,10 @@ def plot_cc_contour_with_levelsets(self, contour_idx=0, levelpaths=None, title=N # plt.ylim(-105, -75) # plt.xlim(181, 101) if save_path is not None: + self.__make_parent_folder(save_path) plt.savefig(save_path, dpi=300) - plt.show() + else: + plt.show() return fig def set_mesh(self, vertices, faces, thickness_values=None): @@ -1047,6 +1051,7 @@ def snap_cc_picture(self, output_path: str, fssurf_file: str | None = None, over This method uses a temporary file to store the mesh and overlay data during the snapshot process. """ + self.__make_parent_folder(output_path) # Skip snapshot if there are no faces if len(self.t) == 0: print("Warning: Cannot create snapshot - no faces in mesh") @@ -1280,6 +1285,13 @@ def load_thickness_values(self, input_path: str, original_thickness_vertices_pat ] self.thickness_values = new_thickness_values + @staticmethod + def __make_parent_folder(filename: str): + """Make the parent folder of the given filename. + """ + output_folder = Path(filename).parent + output_folder.mkdir(parents=False, exist_ok=True) + def to_fs_coordinates(self): """Convert mesh coordinates to FreeSurfer coordinate system. @@ -1299,6 +1311,7 @@ def write_fssurf(self, filename): Returns: The result of the parent class's write_fssurf method. """ + self.__make_parent_folder(filename) return super().write_fssurf(filename) def write_overlay(self, filename): @@ -1310,6 +1323,7 @@ def write_overlay(self, filename): Returns: The result of writing the morph data using nibabel. """ + self.__make_parent_folder(filename) return nib.freesurfer.write_morph_data(filename, self.mesh_vertex_colors) def save_thickness_measurement_points(self, filename): @@ -1321,6 +1335,7 @@ def save_thickness_measurement_points(self, filename): Args: filename (str): Path where to save the CSV file. """ + self.__make_parent_folder(filename) logger.info(f"Saving thickness measurement points to CSV file: {filename}") with open(filename, "w") as f: f.write("slice_idx,vertex_idx\n") @@ -1340,6 +1355,7 @@ def _load_thickness_measurement_points(filename): list: List of arrays containing vertex indices for each slice where thickness was measured. 
""" + self.__make_parent_folder(filename) data = np.loadtxt(filename, delimiter=",", skiprows=1) slice_indices = data[:, 0].astype(int) vertex_indices = data[:, 1].astype(int) diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/cc_postprocessing.py index 761497dc..e2b68e99 100644 --- a/CorpusCallosum/shape/cc_postprocessing.py +++ b/CorpusCallosum/shape/cc_postprocessing.py @@ -14,7 +14,7 @@ from shape.cc_thickness import cc_thickness, convert_to_ras import FastSurferCNN.utils.logging as logging -from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE +from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE, SUBSEGEMNT_LABELS from CorpusCallosum.data.read_write import run_in_background from CorpusCallosum.utils.utils import HiddenPrints from CorpusCallosum.visualization.visualization import plot_contours @@ -219,7 +219,7 @@ def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thi def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac_coords, pc_coords, num_thickness_points, subdivisions, subdivision_method, contour_smoothing, - output_dir, debug_image_path=None, thickness_image_path=None, vox_size=None, + debug_image_path=None, thickness_image_path=None, vox_size=None, save_template=None, surf_file_path=None, overlay_file_path=None, cc_html_path=None, vtk_file_path=None, verbose=False): """Process corpus callosum slices based on selection mode. @@ -238,7 +238,6 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac subdivisions (list[float]): List of fractions for anatomical subdivisions subdivision_method (str): Method for contour subdivision contour_smoothing (float): Gaussian sigma for contour smoothing - output_dir (str): Base output directory debug_image_path (str, optional): Path for debug visualization image verbose (bool): Whether to print progress information save_template (str | Path | None): Directory path where to save template files, or None to skip saving @@ -254,7 +253,7 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac if slice_selection == "middle": cc_mesh = CC_Mesh(num_slices=1) cc_mesh.set_acpc_coords(ac_coords, pc_coords) - cc_mesh.set_resolution(1) # contour is always scaled to 1 mm + cc_mesh.set_resolution(vox_size) # contour is always scaled to 1 mm # Process only the middle slice slice_idx = segmentation.shape[0] // 2 @@ -287,7 +286,7 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac num_slices = segmentation.shape[0] cc_mesh = CC_Mesh(num_slices=num_slices) cc_mesh.set_acpc_coords(ac_coords, pc_coords) - cc_mesh.set_resolution(1) # contour is always scaled to 1 mm + cc_mesh.set_resolution(vox_size) # contour is always scaled to 1 mm # Process multiple slices or specific slice if slice_selection == "all": @@ -379,11 +378,178 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac if verbose: logger.info(f"Saving thickness image to {thickness_image_path}") with HiddenPrints(): - cc_mesh.snap_cc_picture(str(output_dir / thickness_image_path)) + cc_mesh.snap_cc_picture(str(thickness_image_path)) if not slice_results: logger.error("Error: No valid slices were found for postprocessing") exit(1) - return slice_results, IO_processes \ No newline at end of file + return slice_results, IO_processes + + + + +def vectorized_line_test(coords_x, coords_y, line_start, line_end): + """Vectorized version of point_relative_to_line for arrays of points. 
+
+    Args:
+        coords_x (np.ndarray): Array of x coordinates
+        coords_y (np.ndarray): Array of y coordinates
+        line_start (array-like): [x, y] coordinates of line start point
+        line_end (array-like): [x, y] coordinates of line end point
+
+    Returns:
+        np.ndarray: Boolean array where True means point is to the left of the line
+    """
+    # Vector from line_start to line_end
+    line_vec = np.array(line_end) - np.array(line_start)
+
+    # Vectors from line_start to all points (vectorized)
+    point_vec_x = coords_x - line_start[0]
+    point_vec_y = coords_y - line_start[1]
+
+    # Cross product (vectorized): positive means point is to the left of the line
+    cross_products = line_vec[0] * point_vec_y - line_vec[1] * point_vec_x
+
+    return cross_products > 0
+
+
+
+
+def get_unique_contour_points(split_contours):
+    """Get unique contour points from the split contours.
+
+    This is a workaround to retrospectively add a voxel-based sub-division; in the
+    future, we could keep track of the sub-division lines for every sub-division scheme.
+
+    Args:
+        split_contours (list): List of split contours (subsegmentations)
+
+    Returns:
+        list: List of unique contour points
+
+    """
+    # For each contour point, check if it appears in other contours
+    unique_contour_points = []
+
+    for i, contour in enumerate(split_contours):
+        # Get points for this contour
+        contour_points = np.vstack((contour[0], -contour[1])).T  # Shape: (N,2)
+
+        # Check each point against all other contours
+        unique_points = []
+        for point in contour_points:
+            is_unique = True
+
+            # Compare against other contours
+            for j, other_contour in enumerate(split_contours):
+                if i == j:
+                    continue
+
+                other_points = np.vstack((other_contour[0], -other_contour[1])).T
+
+                # Check if point exists in other contour (with small tolerance)
+                if np.any(np.all(np.abs(other_points - point) < 1e-6, axis=1)):
+                    is_unique = False
+                    break
+
+            if is_unique:
+                unique_points.append(point)
+
+        unique_contour_points.append(np.array(unique_points))
+
+    return unique_contour_points
+
+
+def make_subdivision_mask(slice_shape, split_contours):
+    """Create a mask for subdividing the corpus callosum based on split contours.
+
+    This function creates a mask that assigns different labels to different segments of the corpus callosum
+    based on the subdivision lines defined by the split contours. Each segment is labeled with a value from
+    SUBSEGEMNT_LABELS.
+
+    Args:
+        slice_shape (tuple): Shape of the slice (rows, cols)
+        split_contours (list): List of contours defining the subdivisions.
+            Each contour is a tuple of x and y coordinates.
+
+    Returns:
+        ndarray: A mask of shape slice_shape where each pixel is labeled with a value from SUBSEGEMNT_LABELS
+            indicating which subdivision segment it belongs to.
+    """
+
+    # unique contour points are the points where sub-division lines were inserted
+    unique_contour_points = get_unique_contour_points(split_contours)
+    subdivision_segments = unique_contour_points[1:]
+
+    for s in subdivision_segments:
+        if len(s) != 2:
+            logger.error(f'Subdivision segment {s} has {len(s)} points, expected 2')
+
+    # Create coordinate grids for all points in the slice
+    rows, cols = slice_shape
+    y_coords, x_coords = np.mgrid[0:rows, 0:cols]
+
+    # Initialize with first segment label
+    subdivision_mask = np.full(slice_shape, SUBSEGEMNT_LABELS[0], dtype=np.int32)
+
+    # Process each subdivision line
+    for segment_idx, segment_points in enumerate(subdivision_segments):
+        line_start = segment_points[0]
+        line_end = segment_points[-1]
+
+        # Vectorized test: find all points to the right of this line
+        points_right_of_line = vectorized_line_test(x_coords, y_coords, line_start, line_end)
+
+        # All points to the right of this line belong to the next segment or beyond
+        subdivision_mask[points_right_of_line] = SUBSEGEMNT_LABELS[segment_idx + 1]
+
+        # Debug visualization (optional)
+        # import matplotlib.pyplot as plt
+        # fig, ax = plt.subplots(figsize=(10, 8))
+        # ax.imshow(subdivision_mask, cmap='tab10')
+        # ax.plot([line_start[0], line_end[0]], [line_start[1], line_end[1]], 'r-', linewidth=2)
+        # ax.set_title(f'After subdivision line {segment_idx}')
+        # plt.show()
+
+    return subdivision_mask
+
+
+def check_area_changes(contours: list[np.ndarray], threshold: float = 0.3, verbose: bool = False) -> bool:
+    """Check for large changes between consecutive CC areas and issue warnings.
+
+    This function checks if any two consecutive areas have a change greater than
+    the specified threshold (default 30%) and issues a warning if they do.
+
+    Args:
+        contours (list[np.ndarray]): List of contours
+        threshold (float, optional): Threshold for relative change. Defaults to 0.3 (30%).
+        verbose (bool, optional): Whether to log a warning for large changes. Defaults to False.
+
+    Returns:
+        bool: True if no large area changes were detected, False otherwise.
+    """
+
+    areas = [np.sum(np.sqrt(np.sum((np.diff(contour, axis=0))**2, axis=1))) for contour in contours]
+
+    assert len(areas) > 1, "At least two areas are required to check for area changes"
+
+    for i in range(len(areas) - 1):
+        if areas[i] == 0 and areas[i+1] == 0:
+            continue  # Skip if both areas are zero
+
+        if areas[i] == 0 or areas[i+1] == 0:
+            # One area is zero, the other is not - this is a 100% change
+            if verbose:
+                logger.warning(f"Large area change detected: area {i+1} = {areas[i]:.2f} mm², "
+                               f"area {i+2} = {areas[i+1]:.2f} mm² (one area is zero)")
+            return False
+
+        # Calculate relative change
+        relative_change = abs(areas[i+1] - areas[i]) / areas[i]
+
+        if relative_change > threshold:
+            percent_change = relative_change * 100
+            if verbose:
+                logger.warning(f"Large corpus callosum area change between slices detected: "
+                               f"area {i+1} = {areas[i]:.2f} mm², "
+                               f"area {i+2} = {areas[i+1]:.2f} mm² ({percent_change:.1f}% change)")
+            return False
+    return True
\ No newline at end of file
diff --git a/CorpusCallosum/utils/utils.py b/CorpusCallosum/utils/utils.py
new file mode 100644
index 00000000..f5dd00fd
--- /dev/null
+++ b/CorpusCallosum/utils/utils.py
@@ -0,0 +1,12 @@
+import os
+import sys
+
+
+class HiddenPrints:
+    def __enter__(self):
+        self._original_stdout = sys.stdout
+        sys.stdout = open(os.devnull, "w")
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        sys.stdout.close()
+        sys.stdout = self._original_stdout
\ No newline at end of file

From db2c70d767a10840229902e088c0caf543a685b7 Mon Sep 17 00:00:00 2001
From: ClePol
Date: Thu, 25 Sep 2025 16:55:45 +0200
Subject: [PATCH 12/68] updated README and paths

---
 CorpusCallosum/README.md         | 161 +++++++++++++++++++++++--------
 CorpusCallosum/data/constants.py |   2 +-
 2 files changed, 123 insertions(+), 40 deletions(-)

diff --git a/CorpusCallosum/README.md b/CorpusCallosum/README.md
index 4467ba02..57a59197 100644
--- a/CorpusCallosum/README.md
+++ b/CorpusCallosum/README.md
@@ -9,19 +9,18 @@
 This pipeline combines localization and segmentation deep learning models to:
 1. Detect AC (Anterior Commissure) and PC (Posterior Commissure) points
 2. Extract and align midplane slices
 3. Segment the corpus callosum
-4. Perform post-processing for corpus callosum, including thickness analysis, and various shape metrics
+4. Perform advanced morphometry for the corpus callosum, including subdivision, thickness analysis, and various shape metrics
 5. Generate visualizations and measurements
 
 
-## Directory Structure
-
-- `weights/` - Trained model weights
-- `transforms/` - Image preprocessing transformations
-- `shape/` - Shape analysis and post-processing tools
-- `registration/` - Tools for image registration and alignment
-- `data/` - Template data and IO
-- `localization/` - Inference script for AC/PC localization
-- `segmentation/` - Inference scripts for CC/FN segmentation
+## Quickstart
+
+```bash
+python3 fastsurfer_cc.py --subject_dir /path/to/fastsurfer/output --verbose
+```
+
+This produces all standard outputs. Corpus callosum morphometry can then be found in `stats/callosum.CC.midslice.json`, including 100 thickness measurements and the areas of sub-segments.
+Visualizations are placed in `/path/to/fastsurfer/output/qc_snapshots`. For more detailed information, see the following sections.
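+
+For a quick visual check of the results, the outputs can be loaded into FreeSurfer's `freeview`. The following is a minimal sketch, assuming `freeview` is on your `PATH`; the file names follow the standard output layout described below:
+
+```bash
+# Overlay the CC segmentation on the conformed input and load the CC surface
+freeview -v /path/to/fastsurfer/output/mri/orig.mgz \
+         /path/to/fastsurfer/output/mri/callosum_seg_aseg_space.mgz:colormap=lut \
+         -f /path/to/fastsurfer/output/surf/callosum.surf
+```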
## Command Line Interfaces
 
@@ -36,7 +33,7 @@ The main pipeline script performs the complete corpus callosum analysis workflow
 python3 fastsurfer_cc.py --in_mri /path/to/input/mri.mgz --aseg /path/to/input/aseg.mgz --output_dir /path/to/output --verbose
 
 # Using FastSurfer/FreeSurfer subject directory structure
-python3 fastsurfer_cc.py --subject_dir /path/to/freesurfer/subject --verbose
+python3 fastsurfer_cc.py --subject_dir /path/to/fastsurfer/output --verbose
 ```
 
 #### Required Arguments
@@ -49,15 +46,16 @@ Choose one of these input methods:
 - `--output_dir PATH`: Directory for output files
 
 **Option 2: FastSurfer/FreeSurfer subject directory**
-- `--subject_dir PATH`: Subject directory containing standard FreeSurfer structure
+- `--subject_dir PATH`: Subject directory containing standard FastSurfer structure
   - Automatically uses `mri/orig.mgz` and `mri/aparc.DKTatlas+aseg.deep.mgz`
-  - Creates standard output paths in FreeSurfer structure
+  - Creates standard output paths in FastSurfer structure
 
 #### Optional Arguments
 
 **General Options:**
 - `--verbose`: Enable verbose output and debug plots
 - `--debug_output_dir PATH`: Directory for debug outputs
+- `--cpu`: Force CPU usage even when CUDA is available
 
 **Shape Analysis Parameters:**
 - `--num_thickness_points INT`: Number of points for thickness estimation (default: 100)
@@ -68,7 +66,7 @@ Choose one of these input methods:
 - `angular`: Subdivision based on equally spaced angles (Hampel et al.)
 - `eigenvector`: Primary direction (same as FreeSurfer's mri_cc)
 - `--contour_smoothing FLOAT`: Gaussian sigma for smoothing during contour detection (default: 1.0)
-- `--slice_selection {middle,all,INT}`: Which slices to process (default: "middle")
+- `--slice_selection {middle,all,INT}`: Which slices to process (default: "all")
 
 **Custom Output Paths:**
 - `--upright_volume_path PATH`: Path for upright volume output
@@ -105,9 +103,109 @@ python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \
     --save_template /data/templates/sub001
 ```
 
-### Visualization: `cc_visualization.py`
+## Outputs
+
+The pipeline produces the following outputs in the specified output directory:
+
+### Main Pipeline Outputs
+
+**Analysis Results:**
+- `stats/callosum.CC.midslice.json`: Contains detected landmarks and measurements for the middle slice
+- `stats/callosum.CC.all_slices.json`: Enhanced postprocessing results with per-slice analysis
+
+**Transformation Matrices:**
+- `mri/transforms/cc_up.lta`: Transformation from original to upright space (aligned to fsaverage, CC midslice at the center)
+- `mri/transforms/orient_volume.lta`: Transformation to a CC-, AC- & PC-standardized space. The CC is at the center and AC & PC lie on the coordinate axis, standardizing the head orientation.
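+
+The LTA files can be applied with standard FreeSurfer tooling. As a minimal sketch (assuming a FreeSurfer installation; `orig_upright.mgz` is a hypothetical output name, not a file the pipeline writes), the conformed image can be resampled into upright space like this:
+
+```bash
+# Resample orig.mgz into upright space with the cc_up.lta transform;
+# the upright segmentation provides the target geometry.
+mri_vol2vol --mov mri/orig.mgz \
+            --targ mri/callosum_seg_upright.mgz \
+            --lta mri/transforms/cc_up.lta \
+            --o mri/orig_upright.mgz
+```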
+
+**Image Volumes:**
+- `mri/callosum_seg_upright.mgz`: Corpus callosum segmentation in upright space (aligned to fsaverage, matching cc_up.lta)
+- `mri/callosum_seg_aseg_space.mgz`: Corpus callosum segmentation in conformed image orientation (aligned to orig.mgz and other segmentations)
+- `mri/callosum_seg_soft.mgz`: Corpus callosum soft labels (segmentation probabilities, upright space)
+- `mri/fornix_seg_soft.mgz`: Fornix soft labels (segmentation probabilities, upright space)
+- `mri/background_seg_soft.mgz`: Background soft labels (segmentation probabilities, upright space)
+
+
+**Quality Control and Visualizations:**
+- `qc_snapshots/callosum.png`: Debug visualization of corpus callosum contours and thickness measurements
+- `qc_snapshots/callosum_thickness.png`: 3D thickness visualization (when using `--slice_selection all`)
+- `qc_snapshots/corpus_callosum.html`: Interactive 3D mesh visualization (when using `--slice_selection all`)
+
+
+**Surface Files (only provided when using `--slice_selection all`):**
+- `surf/callosum.surf`: FreeSurfer surface format for integration with FreeSurfer tools (e.g. freeview)
+- `surf/callosum.thickness.w`: FreeSurfer overlay file containing thickness values
+- `surf/callosum_mesh.vtk`: VTK format mesh file for 3D visualization
+
+### Template Files (when --save_template is used)
+
+- `contours.txt`: Corpus callosum contour coordinates for visualization
+- `thickness_values.txt`: Thickness measurements at each contour point
+- `measurement_points.txt`: Original vertex indices where thickness was measured
+
+## JSON Output Structure
+
+The pipeline generates two main JSON files with detailed measurements and analysis results:
+
+### `stats/callosum.CC.midslice.json` (Middle Slice Analysis)
+
+This file contains measurements from the middle sagittal slice and includes:
+
+**Shape Measurements (single values):**
+- `total_area`: Total corpus callosum area (mm²)
+- `total_perimeter`: Total perimeter length (mm)
+- `circularity`: Shape circularity measure (4π × area / perimeter²)
+- `cc_index`: Corpus callosum shape index (length/width ratio)
+- `midline_length`: Length along the corpus callosum midline (mm)
+- `curvature`: Average curvature of the midline (degrees), measured by the angle between its sub-segments
+
+**Subdivisions:**
+- `areas`: Areas of CC using an improved Hofer-Frahm sub-division method (mm²). This gives more consistent sub-segments while preserving the original ratios.
+- `areas_hofer_frahm`: Areas using classical Hofer-Frahm subdivision method (mm²)
 
-Creates visualizations of corpus callosum from template files generated by the main pipeline.
+**Thickness Analysis:** +- `thickness`: Average corpus callosum thickness (mm) +- `thickness_profile`: Thickness profile (mm) of the corpus callosum slice (100 thickness values by default, listed from anterior to posterior CC ends) + + +**Volume Measurements (when multiple slices processed):** +- `cc_5mm_volume`: Total CC volume within 5mm slab using voxel counting (mm³) +- `cc_5mm_volume_pv_corrected`: Volume with partial volume correction using CC contours (mm³) + +**Anatomical Landmarks:** +- `ac_center`: Anterior commissure coordinates in original image space +- `pc_center`: Posterior commissure coordinates in original image space +- `ac_center_oriented_volume`: AC coordinates in standardized space (orient_volume.lta) +- `pc_center_oriented_volume`: PC coordinates in standardized space (orient_volume.lta) +- `ac_center_upright`: AC coordinates in upright space (cc_up.lta) +- `pc_center_upright`: PC coordinates in upright space (cc_up.lta) + +**Processing Parameters:** +- `num_slices`: Number of slices analyzed around the midplane + +### `stats/callosum.CC.all_slices.json` (Multi-Slice Analysis) + +This file contains comprehensive per-slice analysis when using `--slice_selection all`: + +**Global Parameters:** +- `slices_in_segmentation`: Total number of slices in the segmentation volume +- `voxel_size`: Voxel dimensions [x, y, z] in mm +- `subdivision_method`: Method used for anatomical subdivision +- `num_thickness_points`: Number of points used for thickness estimation +- `subdivisions`: Subdivision fractions used for regional analysis +- `contour_smoothing`: Gaussian sigma used for contour smoothing +- `slice_selection`: Slice selection mode used + +**Per-Slice Data (`slices` array):** + +Each slice entry contains the shape measurements, thickness analysis and sub-divisions as described above. + + + + +## Visualization: `cc_visualization.py` + +Creates advanced visualizations of corpus callosum from template files generated by the main pipeline. +Useful for visualization of analysis results. #### Basic Usage @@ -210,7 +308,7 @@ python3 cc_visualization.py \ ### 2D Analysis and Visualization -When using `--slice_selection middle` (default) or a specific slice number with `--save_template`: +When using `--slice_selection middle` or a specific slice number with `--save_template`: ```bash # Generate 2D template data (middle slice) @@ -266,36 +364,21 @@ python3 cc_visualization.py \ - Focus is on mid-sagittal cross-sectional measurements - Compatibility with classical corpus callosum studies is needed -## Outputs +**Note:** The default behavior is `--slice_selection all` for comprehensive 3D analysis. Use `--slice_selection middle` to process only the middle slice for faster, traditional 2D analysis. 
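+
+For downstream analysis, the JSON outputs described above are plain dictionaries. A minimal sketch (assuming the default output layout) that reads the midslice stats and re-derives the reported circularity from area and perimeter:
+
+```python
+import json
+import math
+
+# Load the midslice morphometry written by fastsurfer_cc.py
+with open("stats/callosum.CC.midslice.json") as f:
+    stats = json.load(f)
+
+profile = stats["thickness_profile"]  # 100 values, anterior to posterior
+print("mean thickness (mm):", sum(profile) / len(profile))
+
+# circularity = 4 * pi * area / perimeter^2, as stored in `circularity`
+area, perimeter = stats["total_area"], stats["total_perimeter"]
+print("circularity:", 4 * math.pi * area / perimeter ** 2)
+```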
-The pipeline produces the following outputs in the specified output directory: -### Main Pipeline Outputs -- `cc_markers.json`: Contains detected landmarks and measurements -- `cc_postproc_results.json`: Enhanced postprocessing results with per-slice analysis -- `orient_volume.lta`: Transformation matrix for orientation standardization (AC at origin, PC on anterior-posterior axis) -- `upright.lta`: Transformation matrix for midplane alignment (midsagittal plane cuts brain into hemispheres) -- `upright_volume.mgz`: Original volume mapped with upright.lta -- `segmentation.mgz`: Corpus callosum segmentation on midsagittal plane in upright_volume.mgz space -- `segmentation_orig_space.mgz`: Corpus callosum segmentation in original image orientation -- `cc_postprocessing.png`: Visualization of corpus callosum segmentation and thickness analysis +## Visualization Tool Outputs -### Template Files (when --save_template is used) - -- `contours.txt`: Corpus callosum contour coordinates -- `thickness_values.txt`: Thickness measurements at each point -- `measurement_points.txt`: Original vertex indices where thickness was measured +When using `cc_visualization.py`, additional outputs are generated (for advanced users). -### Visualization Outputs - -**3D Mode Outputs (default, when `--twoD` is not specified):** +**3D Mode Outputs (default):** - `cc_mesh.vtk`: VTK format mesh file for 3D visualization -- `cc_mesh.fssurf`: FreeSurfer surface format for integration with FreeSurfer tools -- `cc_mesh_overlay.curv`: FreeSurfer overlay file containing thickness values -- `cc_mesh.html`: Interactive 3D mesh visualization in HTML format +- `cc_mesh.fssurf`: FreeSurfer surface format +- `cc_mesh_overlay.curv`: FreeSurfer overlay file with thickness values +- `cc_mesh.html`: Interactive 3D mesh visualization - `cc_mesh_snap.png`: Snapshot image of the 3D mesh -- `midslice_2d.png`: 2D visualization of the middle slice contour with thickness +- `midslice_2d.png`: 2D visualization of the middle slice **2D Mode Outputs (when `--twoD` is specified):** - `cc_thickness_2d.png`: 2D contour visualization with thickness colormap \ No newline at end of file diff --git a/CorpusCallosum/data/constants.py b/CorpusCallosum/data/constants.py index b41afc7e..552e65e5 100644 --- a/CorpusCallosum/data/constants.py +++ b/CorpusCallosum/data/constants.py @@ -32,5 +32,5 @@ ## surface "surf_file": "surf/callosum.surf", # cc surface file "overlay_file": "surf/callosum.thickness.w", # cc surface overlay file - "vtk_file": "qc_snapshots/callosum_mesh.vtk", # vtk file of cc mesh + "vtk_file": "surf/callosum_mesh.vtk", # vtk file of cc mesh } \ No newline at end of file From 26239b0e2416ce824f49af8ce9fbc54ce9628dfe Mon Sep 17 00:00:00 2001 From: ClePol Date: Thu, 25 Sep 2025 16:56:27 +0200 Subject: [PATCH 13/68] added partial volume corrected volume calculation, error messages and small bugfixes --- CorpusCallosum/data/fsaverage_cc_template.py | 9 ++-- CorpusCallosum/fastsurfer_cc.py | 46 +++++++++++-------- .../segmentation_postprocessing.py | 43 ++++------------- CorpusCallosum/shape/cc_mesh.py | 16 +++++-- CorpusCallosum/shape/cc_postprocessing.py | 37 +++++++++------ 5 files changed, 76 insertions(+), 75 deletions(-) diff --git a/CorpusCallosum/data/fsaverage_cc_template.py b/CorpusCallosum/data/fsaverage_cc_template.py index 00736783..f38b11e8 100644 --- a/CorpusCallosum/data/fsaverage_cc_template.py +++ b/CorpusCallosum/data/fsaverage_cc_template.py @@ -47,13 +47,12 @@ def load_fsaverage_cc_template(): # smooth outside contour # Apply 
smoothing to the outside contour using a moving average
-
-    freesurfer_home = Path(os.environ['FREESURFER_HOME'])
-
-    if not freesurfer_home.exists():
-        raise OSError(f"FREESURFER_HOME environment variable is not set correctly or does not exist: "
-                      f"{freesurfer_home}, either provide your own template or set the "
-                      f"FREESURFER_HOME environment variable")
+    try:
+        freesurfer_home = Path(os.environ['FREESURFER_HOME'])
+    except KeyError as err:
+        raise OSError("FREESURFER_HOME environment variable is not set, either provide your own "
+                      "template or set the FREESURFER_HOME environment variable") from err
 
     fsaverage_seg_path = freesurfer_home / 'subjects' / 'fsaverage' / 'mri' / 'aparc+aseg.mgz'
     fsaverage_seg = nib.load(fsaverage_seg_path)
diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py
index 34dab18c..7f5021c5 100644
--- a/CorpusCallosum/fastsurfer_cc.py
+++ b/CorpusCallosum/fastsurfer_cc.py
@@ -466,7 +466,7 @@ def main(
         surf_file_path: Path for surf file (default: output_dir/surf/callosum.surf)
         overlay_file_path: Path for overlay file (default: output_dir/mri/callosum_seg_aseg_space.mgz)
         cc_html_path: Path for CC HTML file (default: output_dir/qc_snapshots/corpus_callosum.html)
-        vtk_file_path: Path for vtk file (default: output_dir/qc_snapshots/callosum_mesh.vtk)
+        vtk_file_path: Path for vtk file (default: output_dir/surf/callosum_mesh.vtk)
         softlabels_cc_path: Path for cc softlabels (default: output_dir/mri/callosum_seg_soft.mgz)
         softlabels_fn_path: Path for fornix softlabels (default: output_dir/mri/fornix_seg_soft.mgz)
         softlabels_background_path: Path for background softlabels (default: output_dir/mri/background_seg_soft.mgz)
@@ -616,6 +616,7 @@ def main(
         subdivision_method=subdivision_method,
         contour_smoothing=contour_smoothing,
         debug_image_path=debug_image_path,
+        one_debug_image=True,
         surf_file_path=surf_file_path,
         overlay_file_path=overlay_file_path,
         cc_html_path=cc_html_path,
@@ -629,7 +630,7 @@ def main(
 
     outer_contours = [slice_result['split_contours'][0] for slice_result in slice_results]
 
-    if not check_area_changes(outer_contours, verbose=True):
+    if len(outer_contours) > 1 and not check_area_changes(outer_contours, verbose=True):
         logger.warning("Large area changes detected between consecutive slices, "
                        "this is likely due to a segmentation error.")
 
@@ -638,14 +639,18 @@ def main(
     # Get middle slice result for backward compatibility
     middle_slice_result = slice_results[len(slice_results) // 2]
 
-    subdivision_mask = make_subdivision_mask(segmentation.shape[1:], middle_slice_result['split_contours'])
+    if len(middle_slice_result['split_contours']) <= 5:
+        subdivision_mask = make_subdivision_mask(segmentation.shape[1:], middle_slice_result['split_contours'])
+    else:
+        logger.warning("Too many subsegments for lookup table, skipping sub-division of output segmentation.")
+        subdivision_mask = None
 
     # map soft labels to original space (in parallel because this takes a while)
     IO_processes.append(
         run_in_background(
             map_softlabels_to_orig,
-            debug=True,
+            debug=False,
             outputs_soft=outputs_soft,
             orig_fsaverage_vox2vox=orig_fsaverage_vox2vox,
             orig=orig,
@@ -693,19 +698,7 @@ def main(
 
     ########## Save outputs ##########
 
-    if len(outer_contours) > 1:
-        cc_volume_old = segmentation_postprocessing.get_cc_volume(
-            desired_width_mm=5,
-            cc_mask=segmentation == CC_LABEL,
-            voxel_size=orig.header.get_zooms()
-        )
-        cc_volume = segmentation_postprocessing.get_cc_volume_simpsons(
-            desired_width_mm=5,
-            cc_contours=outer_contours,
-            voxel_size=orig.header.get_zooms()
-        )
-    else:
-        cc_volume = None
+
+
 
     # Create backward compatible output_dict for existing pipeline using middle slice
 @@
-724,6 +717,24 @@ def main( "thickness_profile": middle_slice_result["thickness_profile"], } + if len(outer_contours) > 1: + cc_volume_voxel = segmentation_postprocessing.get_cc_volume_voxel( + desired_width_mm=5, + cc_mask=segmentation == CC_LABEL, + voxel_size=orig.header.get_zooms() + ) + cc_volume_contour = segmentation_postprocessing.get_cc_volume_contour( + desired_width_mm=5, + cc_contours=outer_contours, + voxel_size=orig.header.get_zooms() + ) + if verbose: + logger.info(f"CC volume voxel: {cc_volume_voxel}") + logger.info(f"CC volume contour: {cc_volume_contour}") + + output_dict["cc_5mm_volume"] = cc_volume_voxel + output_dict["cc_5mm_volume_pv_corrected"] = cc_volume_contour + # multiply split contour with resolution scale factor for middle slice visualization split_contours = [ split_contour * orig.header.get_zooms()[1] for split_contour in middle_slice_result["split_contours"] @@ -775,7 +786,6 @@ def main( output_dict["pc_center_oriented_volume"] = pc_coords_standardized output_dict["ac_center_upright"] = ac_coords_3d output_dict["pc_center_upright"] = pc_coords_3d - output_dict["cc_5mm_volume"] = cc_volume output_dict["num_slices"] = slices_to_analyze # Convert numpy arrays to lists for JSON serialization diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py index 84988c3e..541295cf 100644 --- a/CorpusCallosum/segmentation/segmentation_postprocessing.py +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -10,7 +10,7 @@ -def get_cc_volume(desired_width_mm: int, cc_mask: np.ndarray, voxel_size: tuple[float, float, float]) -> float: +def get_cc_volume_voxel(desired_width_mm: int, cc_mask: np.ndarray, voxel_size: tuple[float, float, float]) -> float: """Calculate the volume of the corpus callosum in cubic millimeters. This function calculates the volume of the corpus callosum (CC) in cubic millimeters. @@ -62,7 +62,7 @@ def get_cc_volume(desired_width_mm: int, cc_mask: np.ndarray, voxel_size: tuple[ else: raise ValueError(f"Width of CC segmentation is smaller than desired width: {width_mm} < {desired_width_mm}") -def get_cc_volume_simpsons(desired_width_mm: int, cc_contours: list[np.ndarray], +def get_cc_volume_contour(desired_width_mm: int, cc_contours: list[np.ndarray], voxel_size: tuple[float, float, float]) -> float: """Calculate the volume of the corpus callosum in cubic millimeters using Simpson's rule. 
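The contour-based volume in the hunk that follows integrates the per-slice CC cross-sectional areas over the 5 mm slab with Simpson's rule, after extending the area profile to the slab borders. A minimal standalone sketch of these numerics (hypothetical area values; 1 mm slice spacing, which the patched code also assumes):

```python
import numpy as np
from scipy import integrate

areas = np.array([102.0, 105.5, 107.2, 106.1, 103.4])  # hypothetical per-slice CC areas (mm²)
lr_spacing = 1.0                                        # mm between slice centers

# Slice centers sit at 0.5..4.5 mm; np.interp clamps to the outermost
# slices, yielding area estimates at the slab borders (0 mm and 5 mm).
borders = np.interp([0.0, 5.0], np.arange(lr_spacing / 2, 5, lr_spacing), areas)

x = [0.0, 0.5, 1.5, 2.5, 3.5, 4.5, 5.0]
volume = integrate.simpson([borders[0], *areas, borders[1]], x=x)  # mm³
print(f"PV-corrected 5 mm slab volume: {volume:.1f} mm³")
```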
@@ -104,38 +104,15 @@ def get_cc_volume_simpsons(desired_width_mm: int, cc_contours: list[np.ndarray],
     # Calculate spacing between slices (left-right direction)
     lr_spacing = voxel_size[0]  # x-direction voxel size
 
-    # Get current width in mm
-    current_width_mm = len(cc_contours) * lr_spacing
+    # interpolate the cross-sectional areas at the slab borders (0 mm and 5 mm);
+    # np.interp clamps to the outermost slice areas
+    areas_interpolated = np.interp(x=[0, 5], xp=np.arange(lr_spacing/2, 5, lr_spacing), fp=areas)
+
+    # positions of the slab borders and the slice centers (assumes 1 mm slice spacing)
+    measurements = [0, 0.5, 1.5, 2.5, 3.5, 4.5, 5]
+    # can also use cumulative trapezoidal rule
+    return integrate.simpson([areas_interpolated[0]] + areas.tolist() + [areas_interpolated[1]], x=measurements)
 
-    if current_width_mm == desired_width_mm:
-        # Use Simpson's rule directly
-        return integrate.simpson(areas, dx=lr_spacing)
-    elif current_width_mm > desired_width_mm:
-        # Handle partial volumes at edges
-        desired_width_vox = desired_width_mm / lr_spacing
-        fraction_of_voxel_at_edge = (desired_width_vox % 1) / 2
-
-        if fraction_of_voxel_at_edge > 0:
-            # Apply partial volume correction to edge areas
-            areas_corrected = areas.copy()
-            areas_corrected[0] *= fraction_of_voxel_at_edge
-            areas_corrected[-1] *= fraction_of_voxel_at_edge
-
-            # Use Simpson's rule with corrected areas
-            return integrate.simps(areas_corrected, dx=lr_spacing)
-        else:
-            # No partial volumes needed, truncate to desired width
-            desired_slices = int(desired_width_vox)
-            if desired_slices % 2 == 0:
-                desired_slices += 1  # Ensure odd number for Simpson's rule
-
-            start_idx = (len(areas) - desired_slices) // 2
-            end_idx = start_idx + desired_slices
-            truncated_areas = areas[start_idx:end_idx]
-
-            return integrate.simps(truncated_areas, dx=lr_spacing)
-    else:
-        raise ValueError(f"Width of CC segmentation is smaller than desired width: {current_width_mm} < {desired_width_mm}")
 
 def get_largest_cc(seg_arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
     """Get largest connected component from a binary segmentation array.
diff --git a/CorpusCallosum/shape/cc_mesh.py b/CorpusCallosum/shape/cc_mesh.py
index 86d1240a..e246aee5 100644
--- a/CorpusCallosum/shape/cc_mesh.py
+++ b/CorpusCallosum/shape/cc_mesh.py
@@ -1,3 +1,4 @@
+import sys
 import tempfile
 from pathlib import Path
 
@@ -1280,9 +1281,17 @@ def load_thickness_values(self, input_path: str, original_thickness_vertices_pat
                 [loaded_thickness_values[slice_idx], loaded_thickness_values[slice_idx][::-1]]
             )
         else:
-            new_thickness_values[slice_idx][vertex_indices] = loaded_thickness_values[slice_idx][
-                ~np.isnan(loaded_thickness_values[slice_idx])
-            ]
+            try:
+                new_thickness_values[slice_idx][vertex_indices] = loaded_thickness_values[slice_idx][
+                    ~np.isnan(loaded_thickness_values[slice_idx])]
+            except IndexError:
+                print(
+                    f"Tried to load "
+                    f"{np.count_nonzero(~np.isnan(loaded_thickness_values[slice_idx]))} "
+                    f"values, but the template expects {len(vertex_indices)} values, "
+                    "supply a correct template to visualize the thickness values"
+                )
+                sys.exit(1)
         self.thickness_values = new_thickness_values
 
     @staticmethod
@@ -1355,7 +1364,6 @@ def _load_thickness_measurement_points(filename):
         list: List of arrays containing vertex indices for each slice where thickness was measured.
""" - self.__make_parent_folder(filename) data = np.loadtxt(filename, delimiter=",", skiprows=1) slice_indices = data[:, 0].astype(int) vertex_indices = data[:, 1].astype(int) diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/cc_postprocessing.py index e2b68e99..bc3a86c2 100644 --- a/CorpusCallosum/shape/cc_postprocessing.py +++ b/CorpusCallosum/shape/cc_postprocessing.py @@ -219,7 +219,8 @@ def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thi def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac_coords, pc_coords, num_thickness_points, subdivisions, subdivision_method, contour_smoothing, - debug_image_path=None, thickness_image_path=None, vox_size=None, + debug_image_path=None, one_debug_image=False, + thickness_image_path=None, vox_size=None, save_template=None, surf_file_path=None, overlay_file_path=None, cc_html_path=None, vtk_file_path=None, verbose=False): """Process corpus callosum slices based on selection mode. @@ -320,20 +321,26 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac if result is not None: slice_results.append(result) - - debug_path_base, debug_path_ext = str(debug_image_path).rsplit('.', 1) - debug_path_with_postfix = f"{debug_path_base}_slice_{slice_idx}" - debug_output_path_slice = Path(f"{debug_path_with_postfix}.{debug_path_ext}").with_suffix('.png') - - if verbose: - logger.info(f"Saving segmentation qc image to {debug_output_path_slice}") - - current_slice_in_volume = midslices.shape[0] // 2 - num_slices // 2 + slice_idx - # Create visualization for this slice - IO_processes.append(create_visualization(subdivision_method, result, - midslices[current_slice_in_volume:current_slice_in_volume+1], - debug_output_path_slice, ac_coords, pc_coords, - vox_size, f' (Slice {slice_idx})')) + + if (one_debug_image and slice_idx == num_slices // 2) or not one_debug_image: + if not one_debug_image: + debug_path_base, debug_path_ext = str(debug_image_path).rsplit('.', 1) + debug_path_with_postfix = f"{debug_path_base}_slice_{slice_idx}" + + debug_output_path_slice = Path(f"{debug_path_with_postfix}.{debug_path_ext}") + debug_output_path_slice = debug_output_path_slice.with_suffix('.png') + else: + debug_output_path_slice = debug_image_path + + if verbose: + logger.info(f"Saving segmentation qc image to {debug_output_path_slice}") + + current_slice_in_volume = midslices.shape[0] // 2 - num_slices // 2 + slice_idx + # Create visualization for this slice + IO_processes.append(create_visualization(subdivision_method, result, + midslices[current_slice_in_volume:current_slice_in_volume+1], + debug_output_path_slice, ac_coords, pc_coords, + vox_size, f' (Slice {slice_idx})')) if save_template is not None: # Convert to Path object and ensure directory exists From e097c0276abb2edf50c104af8353308f875f4a6d Mon Sep 17 00:00:00 2001 From: ClePol Date: Thu, 25 Sep 2025 23:23:34 +0200 Subject: [PATCH 14/68] sphinx doc build and license --- CorpusCallosum/cc_visualization.py | 9 +- CorpusCallosum/data/constants.py | 15 +++ CorpusCallosum/data/fsaverage_cc_template.py | 33 +++++-- .../data/generate_fsaverage_centroids.py | 13 +++ CorpusCallosum/data/read_write.py | 46 ++++++--- CorpusCallosum/fastsurfer_cc.py | 99 ++++++++++++------- .../localization/localization_inference.py | 14 +++ .../segmentation/segmentation_inference.py | 14 +++ .../segmentation_postprocessing.py | 14 +++ CorpusCallosum/shape/cc_endpoint_heuristic.py | 14 +++ CorpusCallosum/shape/cc_mesh.py | 93 
+++++++++++------ CorpusCallosum/shape/cc_metrics.py | 14 +++ CorpusCallosum/shape/cc_postprocessing.py | 80 ++++++++++----- CorpusCallosum/shape/cc_subsegment_contour.py | 14 +++ CorpusCallosum/shape/cc_thickness.py | 14 +++ .../transforms/localization_transforms.py | 14 +++ .../transforms/segmentation_transforms.py | 14 +++ CorpusCallosum/utils/checkpoint.py | 2 +- CorpusCallosum/utils/utils.py | 14 +++ CorpusCallosum/visualization/visualization.py | 14 +++ doc/api/index.rst | 4 + doc/conf.py | 2 +- 22 files changed, 430 insertions(+), 120 deletions(-) diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index 65abf837..e8790592 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -76,10 +76,11 @@ def main( resolution: Resolution in mm for the mesh smooth_iterations: Number of smoothing iterations to apply to the mesh colormap: Which colormap to use. Options are: - - "red_to_blue": Red -> Orange -> Grey -> Light Blue -> Blue - - "blue_to_red": Blue -> Light Blue -> Grey -> Orange -> Red - - "red_to_yellow": Red -> Yellow -> Light Blue -> Blue - - "yellow_to_red": Yellow -> Light Blue -> Blue -> Red + - "red_to_blue": Red -> Orange -> Grey -> Light Blue -> Blue + - "blue_to_red": Blue -> Light Blue -> Grey -> Orange -> Red + - "red_to_yellow": Red -> Yellow -> Light Blue -> Blue + - "yellow_to_red": Yellow -> Light Blue -> Blue -> Red + color_range: Optional tuple of (min, max) to set fixed color range for the colorbar twoD: If True, generate 2D visualization instead of 3D mesh """ diff --git a/CorpusCallosum/data/constants.py b/CorpusCallosum/data/constants.py index 552e65e5..ba2ef46d 100644 --- a/CorpusCallosum/data/constants.py +++ b/CorpusCallosum/data/constants.py @@ -1,3 +1,18 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + from pathlib import Path ### Constants diff --git a/CorpusCallosum/data/fsaverage_cc_template.py b/CorpusCallosum/data/fsaverage_cc_template.py index f38b11e8..196873ed 100644 --- a/CorpusCallosum/data/fsaverage_cc_template.py +++ b/CorpusCallosum/data/fsaverage_cc_template.py @@ -1,27 +1,42 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import os from pathlib import Path import nibabel as nib import numpy as np from scipy import ndimage -from shape.cc_postprocessing import process_slice + +from CorpusCallosum.shape.cc_postprocessing import process_slice def smooth_contour(contour, window_size=5): """ - Smooth a contour using a moving average filter + Smooth a contour using a moving average filter. - Parameters: - ----------- + Parameters + ---------- contour : tuple of arrays - The contour coordinates (x, y) + The contour coordinates (x, y). window_size : int - Size of the smoothing window + Size of the smoothing window. - Returns: - -------- + Returns + ------- tuple of arrays - The smoothed contour coordinates (x, y) + The smoothed contour coordinates (x, y). """ x, y = contour diff --git a/CorpusCallosum/data/generate_fsaverage_centroids.py b/CorpusCallosum/data/generate_fsaverage_centroids.py index 2cc827d8..443a7fb3 100644 --- a/CorpusCallosum/data/generate_fsaverage_centroids.py +++ b/CorpusCallosum/data/generate_fsaverage_centroids.py @@ -1,4 +1,17 @@ #!/usr/bin/env python3 +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Script to generate static fsaverage centroids file. diff --git a/CorpusCallosum/data/read_write.py b/CorpusCallosum/data/read_write.py index b22e804e..786b97ef 100644 --- a/CorpusCallosum/data/read_write.py +++ b/CorpusCallosum/data/read_write.py @@ -1,3 +1,17 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import multiprocessing import nibabel as nib @@ -10,20 +24,20 @@ def run_in_background(function, debug=False, *args, **kwargs): """Run a function in the background using multiprocessing. - + This function executes the given function either in a separate process (normal mode) or in the current process (debug mode). In debug mode, the function is executed synchronously for easier debugging. 
-    
+    
     Args:
         function: The function to execute
         debug (bool): If True, run synchronously in current process
-        *args: Positional arguments to pass to the function
-        **kwargs: Keyword arguments to pass to the function
-    
+        args: Positional arguments to pass to the function
+        kwargs: Keyword arguments to pass to the function
+    
     Returns:
         multiprocessing.Process or None: Process object if running in background,
-            None if in debug mode
+        None if in debug mode
     """
     if debug:
         function(*args, **kwargs)
@@ -43,17 +57,19 @@ def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: list[int] | None
     Coordinates are returned in RAS (Right-Anterior-Superior) coordinate system.
 
     Args:
-        seg_img (nib.Nifti1Image): Nibabel image containing segmentation labels
-        label_ids (list[int] | None): Optional list of specific label IDs to process.
+        seg_img (nib.Nifti1Image)
+            Nibabel image containing segmentation labels
+        label_ids (list[int] | None)
+            Optional list of specific label IDs to process.
             If None, processes all non-zero labels.
-    
+    
     Returns:
-        If label_ids is None:
-            dict[int, np.ndarray]: Mapping of label IDs to their centroids (x,y,z) in RAS coordinates
-        If label_ids is provided:
-            tuple: Contains:
-                - dict[int, np.ndarray]: Mapping of found label IDs to their centroids
-                - list[int]: List of label IDs that were not found in the image
+        centroids (dict | tuple[dict, list])
+            If label_ids is None, returns a dict mapping label IDs to their centroids (x,y,z) in RAS coordinates.
+            If label_ids is provided, returns a tuple containing:
+            - dict[int, np.ndarray]: Mapping of found label IDs to their centroids
+            - list[int]: List of label IDs that were not found in the image
+
     """
     # Get segmentation data and affine
     seg_data = seg_img.get_fdata()
diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py
index 7f5021c5..be4aefd1 100644
--- a/CorpusCallosum/fastsurfer_cc.py
+++ b/CorpusCallosum/fastsurfer_cc.py
@@ -271,6 +271,7 @@ def centroid_registration(aseg_nib, verbose=False):
         - orig_fsaverage_ras2ras: Transformation matrix from original to fsaverage RAS space
         - fsaverage_hires_affine: High-resolution fsaverage affine matrix
         - fsaverage_header: FSAverage header fields for LTA writing
+
     """
     if verbose:
         print("Centroid registration")
@@ -321,8 +322,9 @@ def localize_ac_pc(midslices, aseg_nib, orig_fsaverage_vox2vox, model_localizati
     Returns:
         tuple: Contains:
-            - ac_coords (np.ndarray): Coordinates of the anterior commissure
-            - pc_coords (np.ndarray): Coordinates of the posterior commissure
+        - ac_coords (np.ndarray): Coordinates of the anterior commissure
+        - pc_coords (np.ndarray): Coordinates of the posterior commissure
+
     """
 
     # get center of third ventricle from aseg and map to fsaverage space
@@ -363,6 +365,7 @@ def segment_cc(midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, sl
     tuple: Contains:
         - segmentation (np.ndarray): Binary segmentation of the corpus callosum
         - outputs_soft (np.ndarray): Soft segmentation probabilities
+
     """
     # get 5 mm of slices output with 9 slices per inference
     midslices_middle = midslices.shape[0] // 2
@@ -439,39 +442,6 @@ def main(
     5. Performs enhanced post-processing analysis
     6.
Saves results and visualizations - Args: - in_mri_path: Path to input MRI file - aseg_path: Path to input segmentation file - output_dir: Directory for output files - slice_selection: Which slices to process ('middle', 'all', or specific slice number) - debug_output_dir: Optional directory for debug outputs - verbose: Flag for verbose output - num_thickness_points: Number of points for thickness estimation - subdivisions: List of subdivision fractions for CC subsegmentation - subdivision_method: Method for contour subdivision - contour_smoothing: Gaussian sigma for smoothing during contour detection - cpu: Force CPU usage even when CUDA is available - upright_volume_path: Path for upright volume output (default: output_dir/upright_volume.mgz) - segmentation_path: Path for segmentation output (default: output_dir/segmentation.mgz) - postproc_results_path: Path for postprocessing results (default: output_dir/cc_postproc_results.json) - cc_markers_path: Path for CC markers output (default: output_dir/cc_markers.json) - upright_lta_path: Path for upright LTA transform (default: output_dir/upright.lta) - orient_volume_lta_path: Path for orientation volume LTA transform (default: output_dir/orient_volume.lta) - orig_space_segmentation_path: Path for segmentation in original space - (default: output_dir/mri/segmentation_orig_space.mgz) - debug_image_path: Path for debug visualization image (default: output_dir/stats/cc_postprocessing.png) - save_template: Directory path where to save contours.txt and thickness_values.txt files - thickness_image_path: Path for thickness image - (default: output_dir/qc_snapshots/corpus_callosum_thickness_3d.png) - surf_file_path: Path for surf file (default: output_dir/surf/callosum.surf) - overlay_file_path: Path for overlay file (default: output_dir/mri/callosum_seg_aseg_space.mgz) - cc_html_path: Path for CC HTML file (default: output_dir/qc_snapshots/corpus_callosum.html) - vtk_file_path: Path for vtk file (default: output_dir/surf/callosum_mesh.vtk) - softlabels_cc_path: Path for cc softlabels (default: output_dir/mri/callosum_seg_soft.mgz) - softlabels_fn_path: Path for fornix softlabels (default: output_dir/mri/fornix_seg_soft.mgz) - softlabels_background_path: Path for background softlabels (default: output_dir/mri/background_seg_soft.mgz) - - The function saves multiple outputs to specified paths or default locations in output_dir: - cc_markers.json: Contains detected landmarks and measurements - midplane_slices.mgz: Extracted midplane slices @@ -479,6 +449,65 @@ def main( - segmentation.mgz: Corpus callosum segmentation - cc_postproc_results.json: Enhanced postprocessing results - Various visualization plots and transformation matrices + + Args: + in_mri_path: + Path to input MRI file + aseg_path: + Path to input segmentation file + output_dir: + Directory for output files + slice_selection: + Which slices to process ('middle', 'all', or specific slice number) + debug_output_dir: + Optional directory for debug outputs + verbose: + Flag for verbose output + num_thickness_points: + Number of points for thickness estimation + subdivisions: + List of subdivision fractions for CC subsegmentation + subdivision_method: + Method for contour subdivision + contour_smoothing: + Gaussian sigma for smoothing during contour detection + save_template: + Directory path where to save contours.txt and thickness_values.txt files + cpu: + Force CPU usage even when CUDA is available + upright_volume_path: + Path for upright volume output (default: 
output_dir/upright_volume.mgz) + segmentation_path: + Path for segmentation output (default: output_dir/segmentation.mgz) + postproc_results_path: + Path for postprocessing results (default: output_dir/cc_postproc_results.json) + cc_markers_path: + Path for CC markers output (default: output_dir/cc_markers.json) + upright_lta_path: + Path for upright LTA transform (default: output_dir/upright.lta) + orient_volume_lta_path: + Path for orientation volume LTA transform (default: output_dir/orient_volume.lta) + surf_file_path: + Path for surf file (default: output_dir/surf/callosum.surf) + overlay_file_path: + Path for overlay file (default: output_dir/mri/callosum_seg_aseg_space.mgz) + cc_html_path: + Path for CC HTML file (default: output_dir/qc_snapshots/corpus_callosum.html) + vtk_file_path: + Path for vtk file (default: output_dir/surf/callosum_mesh.vtk) + orig_space_segmentation_path: + Path for segmentation in original space (default: output_dir/mri/segmentation_orig_space.mgz) + debug_image_path: + Path for debug visualization image (default: output_dir/stats/cc_postprocessing.png) + thickness_image_path: + Path for thickness image (default: output_dir/qc_snapshots/corpus_callosum_thickness_3d.png) + softlabels_cc_path: + Path for cc softlabels (default: output_dir/mri/callosum_seg_soft.mgz) + softlabels_fn_path: + Path for fornix softlabels (default: output_dir/mri/fornix_seg_soft.mgz) + softlabels_background_path: + Path for background softlabels (default: output_dir/mri/background_seg_soft.mgz) + """ if subdivisions is None: diff --git a/CorpusCallosum/localization/localization_inference.py b/CorpusCallosum/localization/localization_inference.py index db7cc84e..e90d1647 100644 --- a/CorpusCallosum/localization/localization_inference.py +++ b/CorpusCallosum/localization/localization_inference.py @@ -1,3 +1,17 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from pathlib import Path import numpy as np diff --git a/CorpusCallosum/segmentation/segmentation_inference.py b/CorpusCallosum/segmentation/segmentation_inference.py index 3334b49d..0e95ec4d 100644 --- a/CorpusCallosum/segmentation/segmentation_inference.py +++ b/CorpusCallosum/segmentation/segmentation_inference.py @@ -1,3 +1,17 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import nibabel as nib import numpy as np import torch diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py index 541295cf..09f4cd7f 100644 --- a/CorpusCallosum/segmentation/segmentation_postprocessing.py +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -1,3 +1,17 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import numpy as np from scipy import integrate, ndimage from skimage.measure import label diff --git a/CorpusCallosum/shape/cc_endpoint_heuristic.py b/CorpusCallosum/shape/cc_endpoint_heuristic.py index 868d07c2..10a9a775 100644 --- a/CorpusCallosum/shape/cc_endpoint_heuristic.py +++ b/CorpusCallosum/shape/cc_endpoint_heuristic.py @@ -1,3 +1,17 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import lapy import numpy as np import scipy.ndimage diff --git a/CorpusCallosum/shape/cc_mesh.py b/CorpusCallosum/shape/cc_mesh.py index e246aee5..702a46e2 100644 --- a/CorpusCallosum/shape/cc_mesh.py +++ b/CorpusCallosum/shape/cc_mesh.py @@ -1,3 +1,17 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import sys import tempfile from pathlib import Path @@ -11,10 +25,10 @@ import pyrr import scipy.interpolate from scipy.ndimage import gaussian_filter1d -from shape.cc_thickness import HiddenPrints, make_mesh_from_contour from whippersnappy.core import snap1 import FastSurferCNN.utils.logging as logging +from CorpusCallosum.shape.cc_thickness import HiddenPrints, make_mesh_from_contour logger = logging.get_logger(__name__) @@ -68,10 +82,14 @@ def add_contour( """Add a contour and its associated thickness values for a specific slice. Args: - slice_idx (int): Index of the slice where the contour should be added. 
-        contour (numpy.ndarray): Array of shape (N, 2) containing 2D contour points.
-        thickness_values (numpy.ndarray): Array of thickness measurements for each contour point.
-        start_end_idx (tuple[int, int], optional): Tuple containing start and end indices for the contour.
+        slice_idx (int):
+            Index of the slice where the contour should be added.
+        contour (numpy.ndarray):
+            Array of shape (N, 2) containing 2D contour points.
+        thickness_values (numpy.ndarray):
+            Array of thickness measurements for each contour point.
+        start_end_idx (tuple[int, int], optional):
+            Tuple containing start and end indices for the contour.
             If None, defaults to (0, len(contour)//2).
         """
         self.contours[slice_idx] = contour
@@ -121,22 +139,32 @@
         to an HTML file or displayed in a web browser.
 
         Args:
-            output_path (str, optional): Path to save the plot. If None, displays the plot interactively.
-            colormap (str, optional): Which colormap to use. Options are:
+            output_path (str, optional):
+                Path to save the plot. If None, displays the plot interactively.
+            colormap (str, optional):
+                Which colormap to use. Options are:
 
                 - "red_to_blue": Red -> Orange -> Grey -> Light Blue -> Blue
                 - "red_to_yellow": Red -> Yellow -> Light Blue -> Blue
                 - "yellow_to_red": Yellow -> Light Blue -> Blue -> Red
                 - "blue_to_red": Blue -> Light Blue -> Grey -> Orange -> Red
+
                 Defaults to "red_to_yellow".
-            thickness_overlay (bool, optional): Whether to overlay thickness values on the mesh.
+            thickness_overlay (bool, optional):
+                Whether to overlay thickness values on the mesh.
                 Defaults to True.
-            show_contours (bool, optional): Whether to show the contours. Defaults to False.
-            show_grid (bool, optional): Whether to show the grid. Defaults to False.
-            color_range (tuple[float, float], optional): Optional tuple of (min, max) to set fixed
+            show_contours (bool, optional):
+                Whether to show the contours. Defaults to False.
+            show_grid (bool, optional):
+                Whether to show the grid. Defaults to False.
+            color_range (tuple[float, float], optional):
+                Optional tuple of (min, max) to set fixed
                 color range. Defaults to None.
-            show_mesh_edges (bool, optional): Whether to show the mesh edges. Defaults to False.
-            legend (str, optional): Legend text for the colorbar. Defaults to "".
-            threshold (tuple[float, float], optional): Values between these thresholds will be shown in grey.
+            show_mesh_edges (bool, optional):
+                Whether to show the mesh edges. Defaults to False.
+            legend (str, optional):
+                Legend text for the colorbar. Defaults to "".
+            threshold (tuple[float, float], optional):
+                Values between these thresholds will be shown in grey.
                 Defaults to (-0.2, 0.2).
         """
         assert self.v is not None and self.t is not None, "Mesh has not been created yet"
@@ -758,19 +786,19 @@ def plot_contour(self, slice_idx: int, output_path: str):
 
     def smooth_contour(self, contour_idx, window_size=5):
         """
-        Smooth a contour using a moving average filter
+        Smooth a contour using a moving average filter.
 
-        Parameters:
-        -----------
-        contour : tuple of arrays
-            The contour coordinates (x, y)
+        Parameters
+        ----------
+        contour_idx : int
+            Index of the stored contour (slice) to smooth.
         window_size : int
-            Size of the smoothing window
+            Size of the smoothing window.
 
-        Returns:
-        --------
+        Returns
+        -------
         tuple of arrays
-            The smoothed contour coordinates (x, y)
+            The smoothed contour coordinates (x, y).
         """
 
         x, y = self.contours[contour_idx].T
@@ -1043,11 +1071,13 @@ def snap_cc_picture(self, output_path: str, fssurf_file: str | None = None, over
         overlay. The image is saved to the specified output path.
Args: - output_path (str): Path where to save the snapshot image. - fssurf_file (str, optional): Path to a FreeSurfer surface file to use for the snapshot - if not provided, - the mesh is saved to a temporary file. - overlay_file (str, optional): Path to a FreeSurfer overlay file to use for the snapshot - if not provided, - the mesh is saved to a temporary file. + output_path (str): + Path where to save the snapshot image. + fssurf_file (str | None): Path to a FreeSurfer surface file to use for the snapshot. If None, + the mesh is saved to a temporary file. Defaults to None. + overlay_file (str | None): Path to a FreeSurfer overlay file to use for the snapshot. If None, + the mesh is saved to a temporary file. Defaults to None. + Note: This method uses a temporary file to store the mesh and overlay data during the snapshot process. @@ -1213,13 +1243,16 @@ def load_thickness_values(self, input_path: str, original_thickness_vertices_pat vertices using a measurement points file. Args: - input_path (str): Path to the CSV file containing thickness values. - original_thickness_vertices_path (str, optional): Path to a file containing the + input_path (str): + Path to the CSV file containing thickness values. + original_thickness_vertices_path (str, optional): + Path to a file containing the indices of vertices where thickness was measured. If None, assumes thickness values correspond to all vertices in order. Raises: - ValueError: If the number of thickness values doesn't match the number of + ValueError: + If the number of thickness values doesn't match the number of measurement points, or if the number of slices is inconsistent. """ data = np.loadtxt(input_path, delimiter=",", skiprows=1) diff --git a/CorpusCallosum/shape/cc_metrics.py b/CorpusCallosum/shape/cc_metrics.py index 96d7daca..f4921363 100644 --- a/CorpusCallosum/shape/cc_metrics.py +++ b/CorpusCallosum/shape/cc_metrics.py @@ -1,3 +1,17 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import numpy as np diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/cc_postprocessing.py index bc3a86c2..ff36cb47 100644 --- a/CorpusCallosum/shape/cc_postprocessing.py +++ b/CorpusCallosum/shape/cc_postprocessing.py @@ -1,21 +1,35 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from pathlib import Path import numpy as np -from shape.cc_endpoint_heuristic import get_endpoints -from shape.cc_mesh import CC_Mesh -from shape.cc_metrics import calculate_cc_index -from shape.cc_subsegment_contour import ( + +import FastSurferCNN.utils.logging as logging +from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE, SUBSEGEMNT_LABELS +from CorpusCallosum.data.read_write import run_in_background +from CorpusCallosum.shape.cc_endpoint_heuristic import get_endpoints +from CorpusCallosum.shape.cc_mesh import CC_Mesh +from CorpusCallosum.shape.cc_metrics import calculate_cc_index +from CorpusCallosum.shape.cc_subsegment_contour import ( get_primary_eigenvector, hampel_subdivide_contour, subdivide_contour, subsegment_midline_orthogonal, transform_to_acpc_standard, ) -from shape.cc_thickness import cc_thickness, convert_to_ras - -import FastSurferCNN.utils.logging as logging -from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE, SUBSEGEMNT_LABELS -from CorpusCallosum.data.read_write import run_in_background +from CorpusCallosum.shape.cc_thickness import cc_thickness, convert_to_ras from CorpusCallosum.utils.utils import HiddenPrints from CorpusCallosum.visualization.visualization import plot_contours @@ -101,19 +115,30 @@ def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thi - Subdivision into anatomical regions Args: - segmentation (np.ndarray): 3D segmentation array - slice_idx (int): Index of the slice to process - ac_coords (np.ndarray): Anterior commissure coordinates - pc_coords (np.ndarray): Posterior commissure coordinates - affine (np.ndarray): 4x4 affine transformation matrix - num_thickness_points (int): Number of points for thickness estimation - subdivisions (list[float]): List of fractions for anatomical subdivisions - subdivision_method (str): Method for contour subdivision ('shape', 'vertical', + segmentation (np.ndarray): + 3D segmentation array + slice_idx (int): + Index of the slice to process + ac_coords (np.ndarray): + Anterior commissure coordinates + pc_coords (np.ndarray): + Posterior commissure coordinates + affine (np.ndarray): + 4x4 affine transformation matrix + num_thickness_points (int): + Number of points for thickness estimation + subdivisions (list[float]): + List of fractions for anatomical subdivisions + subdivision_method (str): + Method for contour subdivision ('shape', 'vertical', 'angular', or 'eigenvector') - contour_smoothing (float): Gaussian sigma for contour smoothing + contour_smoothing (float): + Gaussian sigma for contour smoothing Returns: - dict or None: Dictionary containing measurements if successful, including: + slice_data (dict | None): + Dictionary containing measurements if successful, including: + - cc_index: Corpus callosum shape index - circularity: Shape circularity measure - areas: Areas of subdivided regions @@ -131,6 +156,7 @@ def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thi - slice_index: Index of the processed slice Returns None if no CC is found in the slice. 
+ """ cc_mask_slice = segmentation[slice_idx] == 192 @@ -243,8 +269,9 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac verbose (bool): Whether to print progress information save_template (str | Path | None): Directory path where to save template files, or None to skip saving - Returns: - tuple: Contains: + Returns: + tuple: Contains: + - list: List of slice processing results - list: List of background IO processes """ @@ -477,13 +504,16 @@ def make_subdivision_mask(slice_shape, split_contours): SUBSEGEMNT_LABELS. Args: - slice_shape (tuple): Shape of the slice (rows, cols) - split_contours (list): List of contours defining the subdivisions. + slice_shape (tuple): + Shape of the slice (rows, cols) + split_contours (list): + List of contours defining the subdivisions. Each contour is a tuple of x and y coordinates. Returns: - ndarray: A mask of shape slice_shape where each pixel is labeled with a value from SUBSEGEMNT_LABELS - indicating which subdivision segment it belongs to. + ndarray: + A mask of shape slice_shape where each pixel is labeled with a value from SUBSEGEMNT_LABELS + indicating which subdivision segment it belongs to. """ # unique contour points are the points where sub-division lines were inserted diff --git a/CorpusCallosum/shape/cc_subsegment_contour.py b/CorpusCallosum/shape/cc_subsegment_contour.py index 0ba34e33..25bb89cc 100644 --- a/CorpusCallosum/shape/cc_subsegment_contour.py +++ b/CorpusCallosum/shape/cc_subsegment_contour.py @@ -1,3 +1,17 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import numpy as np from scipy.spatial import ConvexHull diff --git a/CorpusCallosum/shape/cc_thickness.py b/CorpusCallosum/shape/cc_thickness.py index afa8fd31..05dcd55a 100644 --- a/CorpusCallosum/shape/cc_thickness.py +++ b/CorpusCallosum/shape/cc_thickness.py @@ -1,3 +1,17 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import meshpy.triangle as triangle import numpy as np import scipy.interpolate diff --git a/CorpusCallosum/transforms/localization_transforms.py b/CorpusCallosum/transforms/localization_transforms.py index 215e4817..2f106beb 100644 --- a/CorpusCallosum/transforms/localization_transforms.py +++ b/CorpusCallosum/transforms/localization_transforms.py @@ -1,3 +1,17 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import numpy as np from monai.transforms import MapTransform, RandomizableTransform diff --git a/CorpusCallosum/transforms/segmentation_transforms.py b/CorpusCallosum/transforms/segmentation_transforms.py index ad7d7484..565b1dfe 100644 --- a/CorpusCallosum/transforms/segmentation_transforms.py +++ b/CorpusCallosum/transforms/segmentation_transforms.py @@ -1,3 +1,17 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import numpy as np from monai.transforms import MapTransform, RandomizableTransform diff --git a/CorpusCallosum/utils/checkpoint.py b/CorpusCallosum/utils/checkpoint.py index 2fd19f21..355542bd 100644 --- a/CorpusCallosum/utils/checkpoint.py +++ b/CorpusCallosum/utils/checkpoint.py @@ -1,4 +1,4 @@ -# Copyright 2024 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/CorpusCallosum/utils/utils.py b/CorpusCallosum/utils/utils.py index f5dd00fd..4b79e8e6 100644 --- a/CorpusCallosum/utils/utils.py +++ b/CorpusCallosum/utils/utils.py @@ -1,3 +1,17 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import os import sys diff --git a/CorpusCallosum/visualization/visualization.py b/CorpusCallosum/visualization/visualization.py index 29b431da..bcb3b8af 100644 --- a/CorpusCallosum/visualization/visualization.py +++ b/CorpusCallosum/visualization/visualization.py @@ -1,3 +1,17 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from pathlib import Path import matplotlib.pyplot as plt diff --git a/doc/api/index.rst b/doc/api/index.rst index 546cdf4f..fb5492a8 100644 --- a/doc/api/index.rst +++ b/doc/api/index.rst @@ -16,6 +16,10 @@ FastSurfer API CerebNet.datasets.rst CerebNet.models.rst CerebNet.utils.rst + CorpusCallosum.rst + CorpusCallosum_data.rst + CorpusCallosum_shape.rst + CorpusCallosum_utils.rst HypVINN.rst HypVINN.dataloader.rst HypVINN.models.rst diff --git a/doc/conf.py b/doc/conf.py index faa0d8df..f9ea8d7b 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -227,7 +227,7 @@ def import_from_path(module_name, file_path): linkcode_resolve = LinkCodeResolver(gh_url, branch) -_re_script_dirs = "fastsurfercnn|cerebnet|recon_surf|hypvinn" +_re_script_dirs = "fastsurfercnn|cerebnet|recon_surf|hypvinn|corpuscallosum" _up = "^/\\.\\./" _end = "(\\.md)?(#.*)?$" From a9d29637a42eff648c2f8d54deeb2fcf089fcd71 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Fri, 26 Sep 2025 11:51:53 +0200 Subject: [PATCH 15/68] fix typos --- CorpusCallosum/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CorpusCallosum/README.md b/CorpusCallosum/README.md index 57a59197..3b4a53bb 100644 --- a/CorpusCallosum/README.md +++ b/CorpusCallosum/README.md @@ -9,7 +9,7 @@ This pipeline combines localization and segmentation deep learning models to: 1. Detect AC (Anterior Commissure) and PC (Posterior Commissure) points 2. Extract and align midplane slices 3. Segment the corpus callosum -4. Perform advanced morphometry for corpus callosum, including subdivision thickness analysis, and various shape metrics +4. Perform advanced morphometry for corpus callosum, including subdivision, thickness analysis, and various shape metrics 5. 
Generate visualizations and measurements @@ -119,14 +119,14 @@ The pipeline produces the following outputs in the specified output directory: **Image Volumes:** - `mri/callosum_seg_upright.mgz`: Corpus callosum segmentation in upright space (aligned to fsaverage, matching cc_up.lta) -- `mri/callosum_seg_aseg_space.mgz`: Corpus callosum segmentation in conformed image orientation (algined to orig.mgz and other segmentations) +- `mri/callosum_seg_aseg_space.mgz`: Corpus callosum segmentation in conformed image orientation (aligned to orig.mgz and other segmentations) - `mri/callosum_seg_soft.mgz`: Corpus callosum soft labels (segmentation probabilities, upright space) - `mri/fornix_seg_soft.mgz`: Fornix soft labels (segmentation probabilities, upright space) - `mri/background_seg_soft.mgz`: Background soft labels (segmentation probabilities, upright space) **Quality Control and Visualizations:** -- `qc_snapshots/callosum.png`: Debug visualization of corpus callosum contours and thickness measurments +- `qc_snapshots/callosum.png`: Debug visualization of corpus callosum contours and thickness measurements - `qc_snapshots/callosum_thickness.png`: 3D thickness visualization (when using `--slice_selection all`) - `qc_snapshots/corpus_callosum.html`: Interactive 3D mesh visualization (when using `--slice_selection all`) From fd6e8852452cbda4bb1a5a152ea0de332cf4a23c Mon Sep 17 00:00:00 2001 From: ClePol Date: Fri, 26 Sep 2025 12:12:35 +0200 Subject: [PATCH 16/68] added doc files for sphinx --- CorpusCallosum/__init__.py | 23 +++++++++++++++++++++++ CorpusCallosum/data/__init__.py | 0 CorpusCallosum/localization/__init__.py | 0 CorpusCallosum/segmentation/__init__.py | 0 CorpusCallosum/shape/__init__.py | 0 CorpusCallosum/transforms/__init__.py | 0 CorpusCallosum/utils/__init__.py | 0 CorpusCallosum/visualization/__init__.py | 0 doc/api/CorpusCallosum.rst | 11 +++++++++++ doc/api/CorpusCallosum_data.rst | 11 +++++++++++ doc/api/CorpusCallosum_shape.rst | 14 ++++++++++++++ doc/api/CorpusCallosum_utils.rst | 10 ++++++++++ 12 files changed, 69 insertions(+) create mode 100644 CorpusCallosum/__init__.py create mode 100644 CorpusCallosum/data/__init__.py create mode 100644 CorpusCallosum/localization/__init__.py create mode 100644 CorpusCallosum/segmentation/__init__.py create mode 100644 CorpusCallosum/shape/__init__.py create mode 100644 CorpusCallosum/transforms/__init__.py create mode 100644 CorpusCallosum/utils/__init__.py create mode 100644 CorpusCallosum/visualization/__init__.py create mode 100644 doc/api/CorpusCallosum.rst create mode 100644 doc/api/CorpusCallosum_data.rst create mode 100644 doc/api/CorpusCallosum_shape.rst create mode 100644 doc/api/CorpusCallosum_utils.rst diff --git a/CorpusCallosum/__init__.py b/CorpusCallosum/__init__.py new file mode 100644 index 00000000..100ab63a --- /dev/null +++ b/CorpusCallosum/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+__all__ = [
+    "data",
+    "localization",
+    "segmentation",
+    "shape",
+    "transforms",
+    "utils",
+    "visualization",
+]
diff --git a/CorpusCallosum/data/__init__.py b/CorpusCallosum/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CorpusCallosum/localization/__init__.py b/CorpusCallosum/localization/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CorpusCallosum/segmentation/__init__.py b/CorpusCallosum/segmentation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CorpusCallosum/shape/__init__.py b/CorpusCallosum/shape/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CorpusCallosum/transforms/__init__.py b/CorpusCallosum/transforms/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CorpusCallosum/utils/__init__.py b/CorpusCallosum/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CorpusCallosum/visualization/__init__.py b/CorpusCallosum/visualization/__init__.py new file mode 100644 index 00000000..e69de29b
diff --git a/doc/api/CorpusCallosum.rst b/doc/api/CorpusCallosum.rst
new file mode 100644
index 00000000..7d9152e5
--- /dev/null
+++ b/doc/api/CorpusCallosum.rst
@@ -0,0 +1,11 @@
+CorpusCallosum
+==============
+
+.. currentmodule:: CorpusCallosum
+
+.. autosummary::
+    :toctree: generated/
+
+    fastsurfer_cc
+    cc_visualization
+    paint_cc_into_pred
diff --git a/doc/api/CorpusCallosum_data.rst b/doc/api/CorpusCallosum_data.rst
new file mode 100644
index 00000000..a89128e2
--- /dev/null
+++ b/doc/api/CorpusCallosum_data.rst
@@ -0,0 +1,11 @@
+CorpusCallosum.data
+===================
+
+.. currentmodule:: CorpusCallosum.data
+
+.. autosummary::
+    :toctree: generated/
+
+    constants
+    fsaverage_cc_template
+    read_write
diff --git a/doc/api/CorpusCallosum_shape.rst b/doc/api/CorpusCallosum_shape.rst
new file mode 100644
index 00000000..5fcccd7b
--- /dev/null
+++ b/doc/api/CorpusCallosum_shape.rst
@@ -0,0 +1,14 @@
+CorpusCallosum.shape
+====================
+
+.. currentmodule:: CorpusCallosum.shape
+
+.. autosummary::
+    :toctree: generated/
+
+    cc_postprocessing
+    cc_mesh
+    cc_metrics
+    cc_thickness
+    cc_subsegment_contour
+    cc_endpoint_heuristic
diff --git a/doc/api/CorpusCallosum_utils.rst b/doc/api/CorpusCallosum_utils.rst
new file mode 100644
index 00000000..bf06f78b
--- /dev/null
+++ b/doc/api/CorpusCallosum_utils.rst
@@ -0,0 +1,10 @@
+CorpusCallosum.utils
+====================
+
+.. currentmodule:: CorpusCallosum.utils
+
+.. 
autosummary:: + :toctree: generated/ + + checkpoint + utils From d2084c2a3ef8038378e124aafaf697c0d2e8f513 Mon Sep 17 00:00:00 2001 From: ClePol Date: Mon, 29 Sep 2025 13:02:37 +0200 Subject: [PATCH 17/68] bugfixes for hires images, README updates, error handling --- CorpusCallosum/README.md | 2 +- CorpusCallosum/cc_visualization.py | 2 +- CorpusCallosum/fastsurfer_cc.py | 11 ++++---- .../segmentation_postprocessing.py | 26 ++++++++++++++----- CorpusCallosum/shape/cc_mesh.py | 21 ++++++++++----- CorpusCallosum/shape/cc_postprocessing.py | 15 +++++------ 6 files changed, 49 insertions(+), 28 deletions(-) diff --git a/CorpusCallosum/README.md b/CorpusCallosum/README.md index 3b4a53bb..2863569c 100644 --- a/CorpusCallosum/README.md +++ b/CorpusCallosum/README.md @@ -136,7 +136,7 @@ The pipeline produces the following outputs in the specified output directory: - `surf/callosum.thickness.w`: FreeSurfer overlay file containing thickness values - `surf/callosum_mesh.vtk`: VTK format mesh file for 3D visualization -### Template Files (when --save_template is used) +**Template Files (when --save_template is used):** - `contours.txt`: Corpus callosum contour coordinates for visualization - `thickness_values.txt`: Thickness measurements at each contour point diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index e8790592..be447050 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -128,7 +128,7 @@ def main( contour_idx=len(cc_mesh.contours) // 2, save_path=str(output_dir / "midslice_2d.png") ) - cc_mesh.to_fs_coordinates() + cc_mesh.to_fs_coordinates(vox_size=[resolution, resolution, resolution]) cc_mesh.write_vtk(str(output_dir / "cc_mesh.vtk")) cc_mesh.write_fssurf(str(output_dir / "cc_mesh.fssurf")) cc_mesh.write_overlay(str(output_dir / "cc_mesh_overlay.curv")) diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index be4aefd1..6fd8030e 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -532,8 +532,8 @@ def main( # Validate subdivision fractions for i in subdivisions: if i < 0 or i > 1: - print("Error: Subdivision fractions must be between 0 and 1, but got: ", i) - exit(1) + logger.error(f"Error: Subdivision fractions must be between 0 and 1, but got: {i}") + raise ValueError(f"Subdivision fractions must be between 0 and 1, but got: {i}") #### setup variables IO_processes = [] @@ -551,9 +551,9 @@ def main( "center around the mid-sagittal plane)" ) - if not is_conform(orig): + if not is_conform(orig, vox_size='min', img_size=None): logger.error("Error: MRI is not conformed, please run conform.py or mri_convert to conform the image.") - exit(1) + raise ValueError("MRI is not conformed, please run conform.py or mri_convert to conform the image.") # load models device = torch.device("cuda" if torch.cuda.is_available() and not cpu else "cpu") @@ -651,7 +651,8 @@ def main( cc_html_path=cc_html_path, vtk_file_path=vtk_file_path, thickness_image_path=thickness_image_path, - vox_size=orig.header.get_zooms()[0], + vox_size=orig.header.get_zooms(), + image_size=orig.shape, verbose=verbose, save_template=save_template, ) diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py index 09f4cd7f..224c72a3 100644 --- a/CorpusCallosum/segmentation/segmentation_postprocessing.py +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -100,7 +100,11 @@ def 
get_cc_volume_contour(desired_width_mm: int, cc_contours: list[np.ndarray],
 
     # Calculate cross-sectional areas for each contour
     areas = []
+
     for contour in cc_contours:
+        contour = contour.copy()
+        assert voxel_size[1] == voxel_size[2], "in-plane voxel size must be isotropic"
+        contour *= voxel_size[1]
         # Calculate area using the shoelace formula for polygon area
         if contour.shape[1] < 3:
             areas.append(0.0)
@@ -117,15 +121,25 @@ def get_cc_volume_contour(desired_width_mm: int, cc_contours: list[np.ndarray],
 
     # Calculate spacing between slices (left-right direction)
     lr_spacing = voxel_size[0]  # x-direction voxel size
+
+    measurement_points = np.arange(-voxel_size[0]*(areas.shape[0]//2),
+                                   voxel_size[0]*((areas.shape[0]+1)//2), lr_spacing)
 
-    # interpolate areas at 0 and 5
-    areas_interpolated = np.interp(x=[0, 5], xp=np.arange(lr_spacing/2, 5, lr_spacing), fp=areas)
+    # interpolate areas at -2.5 and 2.5
+    areas_interpolated = np.interp(x=[-2.5, 2.5],
+                                   xp=measurement_points,
+                                   fp=areas)
+    # remove measurement points that are outside of the desired range
+    # not sure if this can happen, but let's be safe
+    outside_range = (measurement_points < -2.5) | (measurement_points > 2.5)
+    measurement_points = [-2.5] + measurement_points[~outside_range].tolist() + [2.5]
+    areas = [areas_interpolated[0]] + areas[~outside_range].tolist() + [areas_interpolated[1]]
+
 
-    measurements = [0,0.5,1.5,2.5,3.5,4.5,5]
-    # can also use cumulative trapezoidal rule
-    return integrate.simpson([areas_interpolated[0]] + areas.tolist() + [areas_interpolated[1]], x=measurements)
+    # can also use trapezoidal rule
+    return integrate.simpson(areas, x=measurement_points)
 
 
 def get_largest_cc(seg_arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
@@ -192,7 +206,7 @@ def clean_cc_segmentation(seg_arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
 
     if 250 not in unique_labels:
         clean_seg[seg_arr == 250] = 250
-        mask [seg_arr == 250] = True
+        mask[seg_arr == 250] = True
     if 192 not in unique_labels:
         clean_seg[seg_arr == 192] = 192
         mask[seg_arr == 192] = True
diff --git a/CorpusCallosum/shape/cc_mesh.py b/CorpusCallosum/shape/cc_mesh.py
index 702a46e2..fd78ae2e 100644
--- a/CorpusCallosum/shape/cc_mesh.py
+++ b/CorpusCallosum/shape/cc_mesh.py
@@ -1317,14 +1317,19 @@ def load_thickness_values(self, input_path: str, original_thickness_vertices_pat
             try:
                 new_thickness_values[slice_idx][vertex_indices] = loaded_thickness_values[slice_idx][
                     ~np.isnan(loaded_thickness_values[slice_idx])]
-            except IndexError:
-                print(
+            except IndexError as err:
+                logger.error(
                     f"Tried to load "
                     f"{loaded_thickness_values[slice_idx][~np.isnan(loaded_thickness_values[slice_idx])]} "
                     f"values, but template has {new_thickness_values[slice_idx][vertex_indices]} values, "
                     "supply a correct template to visualize the thickness values"
                 )
-                sys.exit(1)
+                raise ValueError(
+                    f"Tried to load "
+                    f"{loaded_thickness_values[slice_idx][~np.isnan(loaded_thickness_values[slice_idx])]} "
+                    f"values, but template has {new_thickness_values[slice_idx][vertex_indices]} values, "
+                    "supply a correct template to visualize the thickness values"
+                ) from err
         self.thickness_values = new_thickness_values
 
     @staticmethod
@@ -1334,15 +1339,17 @@ def __make_parent_folder(filename: str):
         output_folder = Path(filename).parent
         output_folder.mkdir(parents=False, exist_ok=True)
 
-    def to_fs_coordinates(self):
+    def to_fs_coordinates(self, vox_size: tuple[int, int, int], image_size: tuple[int, int, int]):
         """Convert mesh coordinates to FreeSurfer coordinate system.
Transforms the mesh vertices from the original coordinate system to the FreeSurfer coordinate system by reordering axes and applying appropriate offsets. """ - self.v = self.v[:, [2, 0, 1]] - self.v[:, 1] -= 128 - self.v[:, 2] += 128 + self.v = self.v[:, [2, 0, 1]] # LIA to ALI? + self.v *= (vox_size[0] **2) ## ??? + self.v[:, 1] -= image_size[1] * vox_size[1] // 2 # move 0 to center of image + self.v[:, 2] += image_size[2] * vox_size[2] // 2 + def write_fssurf(self, filename): """Write the mesh to a FreeSurfer surface file. diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/cc_postprocessing.py index ff36cb47..99e26350 100644 --- a/CorpusCallosum/shape/cc_postprocessing.py +++ b/CorpusCallosum/shape/cc_postprocessing.py @@ -134,7 +134,6 @@ def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thi 'angular', or 'eigenvector') contour_smoothing (float): Gaussian sigma for contour smoothing - Returns: slice_data (dict | None): Dictionary containing measurements if successful, including: @@ -246,7 +245,7 @@ def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thi def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac_coords, pc_coords, num_thickness_points, subdivisions, subdivision_method, contour_smoothing, debug_image_path=None, one_debug_image=False, - thickness_image_path=None, vox_size=None, + thickness_image_path=None, vox_size=None, image_size=None, save_template=None, surf_file_path=None, overlay_file_path=None, cc_html_path=None, vtk_file_path=None, verbose=False): """Process corpus callosum slices based on selection mode. @@ -281,7 +280,7 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac if slice_selection == "middle": cc_mesh = CC_Mesh(num_slices=1) cc_mesh.set_acpc_coords(ac_coords, pc_coords) - cc_mesh.set_resolution(vox_size) # contour is always scaled to 1 mm + cc_mesh.set_resolution(vox_size[0]) # contour is always scaled to 1 mm # Process only the middle slice slice_idx = segmentation.shape[0] // 2 @@ -309,12 +308,12 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac if verbose: logger.info(f"Saving segmentation qc image to {debug_image_path}") IO_processes.append(create_visualization(subdivision_method, result, midslices, - debug_image_path, ac_coords, pc_coords, vox_size)) + debug_image_path, ac_coords, pc_coords, vox_size[0])) else: num_slices = segmentation.shape[0] cc_mesh = CC_Mesh(num_slices=num_slices) cc_mesh.set_acpc_coords(ac_coords, pc_coords) - cc_mesh.set_resolution(vox_size) # contour is always scaled to 1 mm + cc_mesh.set_resolution(vox_size[0]) # contour is always scaled to 1 mm # Process multiple slices or specific slice if slice_selection == "all": @@ -367,7 +366,7 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac IO_processes.append(create_visualization(subdivision_method, result, midslices[current_slice_in_volume:current_slice_in_volume+1], debug_output_path_slice, ac_coords, pc_coords, - vox_size, f' (Slice {slice_idx})')) + vox_size[0], f' (Slice {slice_idx})')) if save_template is not None: # Convert to Path object and ensure directory exists @@ -394,7 +393,7 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac #cc_mesh.write_vtk(str(output_dir / 'cc_mesh.vtk')) - cc_mesh.to_fs_coordinates() + cc_mesh.to_fs_coordinates(vox_size=vox_size, image_size=image_size) if overlay_file_path is not None: if verbose: @@ -417,7 +416,7 
@@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac
     if not slice_results:
         logger.error("Error: No valid slices were found for postprocessing")
-        exit(1)
+        raise ValueError("No valid slices were found for postprocessing")
     return slice_results, IO_processes

From 64f5ec62d8ecb7860d7693a98c45f5f43180d209 Mon Sep 17 00:00:00 2001
From: ClePol
Date: Mon, 29 Sep 2025 16:41:10 +0200
Subject: [PATCH 18/68] improved contour extraction for thin CC and surface coordinates

---
 CorpusCallosum/fastsurfer_cc.py               |   9 +-
 .../segmentation_postprocessing.py            | 288 ++++++++++++++++--
 CorpusCallosum/shape/cc_endpoint_heuristic.py | 120 +++++++-
 CorpusCallosum/shape/cc_mesh.py               |  24 +-
 CorpusCallosum/shape/cc_thickness.py          |   2 +-
 5 files changed, 378 insertions(+), 65 deletions(-)

diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py
index 6fd8030e..2a3ab95c 100644
--- a/CorpusCallosum/fastsurfer_cc.py
+++ b/CorpusCallosum/fastsurfer_cc.py
@@ -100,8 +100,8 @@ def options_parse() -> argparse.Namespace:
     parser.add_argument(
         "--contour_smoothing",
         type=float,
-        default=1.0,
-        help="Gaussian sigma for smoothing during contour detection. Default is 1.0, higher values mean a smoother "
+        default=5,
+        help="Window size for smoothing during contour detection. Default is 5, higher values mean a smoother "
         "outline, at the cost of precision.",
     )
     parser.add_argument(
@@ -390,7 +390,7 @@ def segment_cc(midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, sl
         or np.any(cc_volume_mask[:, :, 0])
         or np.any(cc_volume_mask[:, :, -1])
     ):
-        print("Warning: CC volume mask touches the edge of the segmentation field-of-view, CC might be truncated")
+        print("Warning: CC volume mask touches the edge of the segmentation field-of-view, CC might be truncated")
 
     # get voxels that were removed during cleaning
     removed_voxels = pre_clean_segmentation != segmentation
@@ -411,7 +411,7 @@ def main(
     num_thickness_points: int = 100,
     subdivisions: list[float] | None = None,
     subdivision_method: str = "shape",
-    contour_smoothing: float = 1.0,
+    contour_smoothing: float = 5,
     save_template: str | Path | None = None,
     cpu: bool = False,
     # output paths
@@ -606,7 +606,6 @@ def main(
         midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, slices_to_analyze
     )
 
-
     # calculate affine for segmentation volume
     orig_to_seg = np.eye(4)
     orig_to_seg[0, 3] = -FSAVERAGE_MIDDLE + slices_to_analyze // 2
diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py
index 224c72a3..3fa49738 100644
--- a/CorpusCallosum/segmentation/segmentation_postprocessing.py
+++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py
@@ -14,6 +14,7 @@
 
 import numpy as np
 from scipy import integrate, ndimage
+from scipy.spatial.distance import cdist
 from skimage.measure import label
 
 import FastSurferCNN.utils.logging as logging
@@ -22,6 +23,205 @@
 logger = logging.get_logger(__name__)
 
 
+def find_component_boundaries(labels_arr: np.ndarray, component_id: int) -> np.ndarray:
+    """Find boundary voxels of a connected component.
+ + Args: + labels_arr (np.ndarray): Labeled array from connected components analysis + component_id (int): ID of the component to find boundaries for + + Returns: + np.ndarray: Array of boundary coordinates (N, 3) + """ + component_mask = labels_arr == component_id + + # Create a structuring element for 6-connectivity (face neighbors only) + struct = ndimage.generate_binary_structure(3, 1) + + # Erode the component to find internal voxels + eroded = ndimage.binary_erosion(component_mask, structure=struct) + + # Boundary is the difference between original and eroded + boundary = component_mask & ~eroded + + return np.array(np.where(boundary)).T + + +def find_minimal_connection_path(boundary1: np.ndarray, boundary2: np.ndarray, + max_distance: float = 3.0) -> tuple[np.ndarray, np.ndarray] | None: + """Find the minimal connection path between two component boundaries. + + Args: + boundary1 (np.ndarray): Boundary coordinates of first component (N1, 3) + boundary2 (np.ndarray): Boundary coordinates of second component (N2, 3) + max_distance (float): Maximum distance to consider for connection + + Returns: + tuple | None: (point1, point2) coordinates of closest points if within max_distance, None otherwise + """ + if len(boundary1) == 0 or len(boundary2) == 0: + return None + + # Calculate pairwise distances between all boundary points + distances = cdist(boundary1, boundary2, metric='euclidean') + + # Find the minimum distance and corresponding points + min_idx = np.unravel_index(np.argmin(distances), distances.shape) + min_distance = distances[min_idx] + + if min_distance <= max_distance: + point1 = boundary1[min_idx[0]] + point2 = boundary2[min_idx[1]] + return point1, point2 + + return None + + +def create_connection_line(point1: np.ndarray, point2: np.ndarray) -> list[tuple[int, int, int]]: + """Create a line of voxels connecting two points using simplified 3D line algorithm. + + Args: + point1 (np.ndarray): Starting point coordinates (3,) + point2 (np.ndarray): Ending point coordinates (3,) + + Returns: + list: List of (x, y, z) coordinates forming the connection line + """ + x1, y1, z1 = map(int, point1) + x2, y2, z2 = map(int, point2) + + line_points = [] + + # Calculate the number of steps needed + dx = abs(x2 - x1) + dy = abs(y2 - y1) + dz = abs(z2 - z1) + + steps = max(dx, dy, dz) + + if steps == 0: + return [(x1, y1, z1)] + + # Calculate increments for each dimension + x_inc = (x2 - x1) / steps + y_inc = (y2 - y1) / steps + z_inc = (z2 - z1) / steps + + # Generate points along the line + for i in range(steps + 1): + x = int(round(x1 + i * x_inc)) + y = int(round(y1 + i * y_inc)) + z = int(round(z1 + i * z_inc)) + line_points.append((x, y, z)) + + return line_points + + +def connect_nearby_components(seg_arr: np.ndarray, max_connection_distance: float = 3.0) -> np.ndarray: + """Connect nearby disconnected components that should be connected. + + This function identifies disconnected components in the segmentation and creates + minimal connections between components that are close to each other. 
+ + Args: + seg_arr (np.ndarray): Input binary segmentation array + max_connection_distance (float): Maximum distance to connect components + + Returns: + np.ndarray: Segmentation array with minimal connections added + """ + + # Create a copy to modify + connected_seg = seg_arr.copy() + + # Find connected components without dilation first + labels_cc = label(seg_arr, connectivity=3, background=0) + + # Get component sizes (excluding background) + bincount = np.bincount(labels_cc.flat) + component_ids = np.where(bincount > 0)[0][1:] # Exclude background (0) + + if len(component_ids) <= 1: + return connected_seg # Only one component, no connections needed + + # Sort components by size (largest first) + component_sizes = [(comp_id, bincount[comp_id]) for comp_id in component_ids] + component_sizes.sort(key=lambda x: x[1], reverse=True) + + # Use the largest component as the reference + main_component_id = component_sizes[0][0] + + + + logger.info(f"Found {len(component_ids)} disconnected components. " + f"Attempting to connect smaller components to main component (size: {component_sizes[0][1]})") + + connections_made = 0 + + # Try to connect each smaller component to the main component + for comp_id, comp_size in component_sizes[1:]: + if comp_size < 5: # Skip very small components (likely noise) + logger.debug(f"Skipping tiny component {comp_id} with size {comp_size}") + continue + + # Find boundaries of both components + main_boundary = find_component_boundaries(labels_cc, main_component_id) + comp_boundary = find_component_boundaries(labels_cc, comp_id) + + # Find minimal connection path + connection = find_minimal_connection_path(main_boundary, comp_boundary, max_connection_distance) + + if connection is not None: + point1, point2 = connection + distance = np.linalg.norm(point2 - point1) + + logger.debug(f"Connecting component {comp_id} (size: {comp_size}) to main component. 
" + f"Distance: {distance:.2f} voxels") + + # Create connection line + connection_line = create_connection_line(point1, point2) + + # Add connection voxels to the segmentation + # Use the same label as the original segmentation at the connection points + connection_label = seg_arr[point1[0], point1[1], point1[2]] if \ + seg_arr[point1[0], point1[1], point1[2]] != 0 else \ + seg_arr[point2[0], point2[1], point2[2]] + + for x, y, z in connection_line: + if (0 <= x < connected_seg.shape[0] and + 0 <= y < connected_seg.shape[1] and + 0 <= z < connected_seg.shape[2]): + if connected_seg[x, y, z] == 0: # Only fill empty voxels + connected_seg[x, y, z] = connection_label + + connections_made += 1 + else: + logger.debug(f"Component {comp_id} (size: {comp_size}) too far from main component") + + logger.info(f"Created {connections_made} minimal connections between components") + + + # Plot components for visualization + # import matplotlib.pyplot as plt + # n_components = len(component_sizes) + # fig, axes = plt.subplots(1, n_components + 1, figsize=(5*(n_components + 1), 5)) + # if n_components == 1: + # axes = [axes] + # # Plot each component in a different color + # for i, (comp_id, comp_size) in enumerate(component_sizes): + # component_mask = labels_cc == comp_id + # axes[i].imshow(component_mask[component_mask.shape[0]//2], cmap='gray') + # axes[i].set_title(f'Component {comp_id}\nSize: {comp_size}') + # axes[i].axis('off') + + # # Plot the connected segmentation + # axes[-1].imshow(connected_seg[connected_seg.shape[0]//2], cmap='gray') + # axes[-1].set_title('Connected Segmentation') + # axes[-1].axis('off') + # plt.tight_layout() + # plt.show() + + return connected_seg def get_cc_volume_voxel(desired_width_mm: int, cc_mask: np.ndarray, voxel_size: tuple[float, float, float]) -> float: @@ -142,48 +342,69 @@ def get_cc_volume_contour(desired_width_mm: int, cc_contours: list[np.ndarray], return integrate.simpson(areas, x=measurement_points) -def get_largest_cc(seg_arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]: - """Get largest connected component from a binary segmentation array. +def get_largest_cc(seg_arr: np.ndarray, max_connection_distance: float = 3.0) -> tuple[np.ndarray, np.ndarray]: + """Get largest connected component from a binary segmentation array with minimal connections. - This function takes a binary segmentation array, dilates it, finds connected components, - and returns the largest component (excluding background) along with its mask. + This function takes a binary segmentation array, attempts to connect nearby disconnected + components that should be connected, then finds the largest connected component. + It first tries to establish minimal connections between close components before + falling back to dilation if no connections are made. 
    Args:
        seg_arr (np.ndarray): Input binary segmentation array
+        max_connection_distance (float): Maximum distance to connect components (default: 3.0)
 
     Returns:
-        tuple: A tuple containing:
-            - clean_seg (np.ndarray): Segmentation array with only the largest connected component
-            - largest_cc (np.ndarray): Binary mask of the largest connected component
+        np.ndarray: Binary mask of the largest connected component
     """
-    # generate dilatation structure
-    struct1 = ndimage.generate_binary_structure(3, 3)
-    # Dilate prediction
-    mask = ndimage.binary_dilation(seg_arr, structure=struct1, iterations=1, ).astype(np.uint8)
-    # Get connected components
+    # First attempt: try to connect nearby components with minimal connections
+    connected_seg = connect_nearby_components(seg_arr, max_connection_distance)
+
+    # Check if connections were successful by comparing connectivity
+    original_labels = label(seg_arr, connectivity=3, background=0)
+    connected_labels = label(connected_seg, connectivity=3, background=0)
+
+    original_components = len(np.unique(original_labels)) - 1  # Exclude background
+    connected_components = len(np.unique(connected_labels)) - 1  # Exclude background
+
+    # Use the (possibly) connected segmentation; when no connections were made,
+    # connected_seg is identical to seg_arr, so mask is always defined
+    mask = connected_seg
+    if connected_components < original_components:
+        logger.info(f"Successfully reduced components from {original_components} to {connected_components} "
+                    "using minimal connections")
+    # else:
+    #     logger.info("No connections made, falling back to dilation approach")
+    #     # Fallback: use the original dilation approach
+    #     struct1 = ndimage.generate_binary_structure(3, 3)
+    #     mask = ndimage.binary_dilation(seg_arr, structure=struct1, iterations=1).astype(np.uint8)
+
+    # Get connected components from the processed mask
     labels_cc = label(mask, connectivity=3, background=0)
-    # Get componnets count
+
+    # Get component counts
     bincount = np.bincount(labels_cc.flat)
-    # Get background label, assumption that background is the biggest connected component
+
+    # Get background label (assumed to be the largest component)
     background = np.argmax(bincount)
     bincount[background] = -1
+    # Get largest connected component
     largest_cc = labels_cc == np.argmax(bincount)
-    # Apply mask
-    clean_seg = seg_arr * largest_cc
-    return clean_seg,largest_cc
+    return largest_cc
 
 
-def clean_cc_segmentation(seg_arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+def clean_cc_segmentation(seg_arr: np.ndarray, max_connection_distance: float = 3.0) -> tuple[np.ndarray, np.ndarray]:
     """Clean corpus callosum segmentation by removing non-connected components.
 
     This function processes a segmentation array to clean up the corpus callosum
     (CC) by removing non-connected components. It first isolates the CC (label 192),
-    removes non-connected components, then adds the fornix (label 250), and
-    finally removes non-connected components from the combined CC and fornix.
+    attempts to connect nearby disconnected components, then adds the fornix (label 250),
+    and finally removes non-connected components from the combined CC and fornix.
    Args:
        seg_arr (np.ndarray): Input segmentation array with CC (192) and fornix (250) labels
+        max_connection_distance (float): Maximum distance to connect components (default: 3.0)
 
     Returns:
         tuple: A tuple containing:
@@ -191,23 +412,24 @@ def clean_cc_segmentation(seg_arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
             connected component of CC and fornix
             - mask (np.ndarray): Binary mask of the largest connected component
     """
-    #Remove non connected components from the CC alone
-    clean_seg = np.zeros_like(seg_arr)
-    clean_seg[seg_arr == CC_LABEL] = CC_LABEL
-    clean_seg,_ = get_largest_cc(clean_seg)
+    # Remove non connected components from the CC alone, with minimal connections
+    cc_seg = np.zeros_like(seg_arr)
+    cc_seg[seg_arr == CC_LABEL] = CC_LABEL
 
-    #Add fornix to the CC labels
-    clean_seg[seg_arr == FORNIX_LABEL] = FORNIX_LABEL
+    cc_label_cleaned = np.zeros_like(cc_seg)
+    for i in range(cc_seg.shape[0]):
+        cc_label_cleaned[i] = get_largest_cc(cc_seg[None,i], max_connection_distance)
+        # import matplotlib.pyplot as plt
+        # fig, ax = plt.subplots(1,3)
+        # ax[0].imshow(cc_seg[i])
+        # ax[1].imshow(mask[i])
+        # ax[2].imshow(cc_seg[i] - mask[i]*CC_LABEL)  # difference between pre and post clean
+        # plt.show()
 
-    #Remove non connected components from CC & Fornix
-    clean_seg, mask = get_largest_cc(clean_seg)
-    unique_labels = np.unique(clean_seg)
+    # Add fornix to the CC labels
+    clean_seg = np.zeros_like(seg_arr)
+    clean_seg[cc_label_cleaned > 0] = CC_LABEL
+    clean_seg[seg_arr == FORNIX_LABEL] = FORNIX_LABEL
 
-    if 250 not in unique_labels:
-        clean_seg[seg_arr == 250] = 250
-        mask [seg_arr == 250] = True
-    if 192 not in unique_labels:
-        clean_seg[seg_arr == 192] = 192
-        mask[seg_arr == 192] = True
-    return clean_seg, mask
+
+    return clean_seg, cc_label_cleaned > 0
diff --git a/CorpusCallosum/shape/cc_endpoint_heuristic.py b/CorpusCallosum/shape/cc_endpoint_heuristic.py
index 10a9a775..83df78c0 100644
--- a/CorpusCallosum/shape/cc_endpoint_heuristic.py
+++ b/CorpusCallosum/shape/cc_endpoint_heuristic.py
@@ -16,9 +16,117 @@
 import numpy as np
 import scipy.ndimage
 import skimage.measure
+from scipy.ndimage import label
 
 
-def get_endpoints(cc_mask, AC_2d, PC_2d, resolution, return_coordinates=True, contour_smoothing=1.0):
+def smooth_contour(x, y, window_size):
+    """
+    Smooth a contour using a moving average filter.
+    """
+
+    # Ensure the window size is odd
+    if window_size % 2 == 0:
+        window_size += 1
+
+    # Create a padded version of the arrays to handle the edges
+    x_padded = np.pad(x, (window_size // 2, window_size // 2), mode="wrap")
+    y_padded = np.pad(y, (window_size // 2, window_size // 2), mode="wrap")
+
+    # Apply moving average
+    x_smoothed = np.zeros_like(x)
+    y_smoothed = np.zeros_like(y)
+
+    for i in range(len(x)):
+        x_smoothed[i] = np.mean(x_padded[i : i + window_size])
+        y_smoothed[i] = np.mean(y_padded[i : i + window_size])
+
+    # note: the loop above already fills exactly len(x) smoothed values, so the
+    # wrap padding does not have to be trimmed from the result
+
+    return x_smoothed, y_smoothed
+
+
+def connect_diagonally_connected_components(cc_mask):
+    """
+    Connects diagonally connected components in the CC mask.
+ """ + + # Create padded mask to handle boundary conditions + padded_mask = np.pad(cc_mask, pad_width=1, mode='constant', constant_values=0) + + # Get center pixels and diagonal neighbors + center = padded_mask[1:-1, 1:-1] + + # Direct neighbors (4-connectivity) + left = padded_mask[1:-1, :-2] # left + right = padded_mask[1:-1, 2:] # right + up = padded_mask[:-2, 1:-1] # up + down = padded_mask[2:, 1:-1] # down + + # Diagonal neighbors + up_left = padded_mask[:-2, :-2] # up-left + up_right = padded_mask[:-2, 2:] # up-right + down_left = padded_mask[2:, :-2] # down-left + down_right = padded_mask[2:, 2:] # down-right + + potential_diagonal_gaps = (center == 0) & ( + ((up_left > 0) & ((right > 0) | (down > 0))) | + ((up_right > 0) & ((left > 0) | (down > 0))) | + ((down_left > 0) & ((right > 0) | (up > 0))) | + ((down_right > 0) & ((left > 0) | (up > 0))) + ) + + + # Get connected components before filling using 4-connectivity + # This way, diagonal-only connections are treated as separate components + structure_4conn = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]) + _, num_components_before = label(cc_mask, structure=structure_4conn) + + # For each potential gap, check if filling it would reduce the number of components + connects_diagonals = np.zeros_like(potential_diagonal_gaps) + gap_positions = np.where(potential_diagonal_gaps) + + for i, j in zip(gap_positions[0], gap_positions[1], strict=True): + # Temporarily fill this gap + test_mask = cc_mask.copy() + test_mask[i, j] = 1 + + # Check connected components after filling + _, num_components_after = label(test_mask, structure=structure_4conn) + + # Only fill if it actually connects previously disconnected components + if num_components_after < num_components_before: + connects_diagonals[i, j] = True + + # Fill the identified diagonal gaps that actually improve connectivity + cc_mask[connects_diagonals] = 1 + + +def extract_cc_contour(cc_mask, contour_smoothing=5): + """ + Extracts the contour of the CC from the mask. 
+ """ + # cc_mask_orig = cc_mask + cc_mask = cc_mask.copy() + + connect_diagonally_connected_components(cc_mask) + + contour = skimage.measure.find_contours(cc_mask, level=0.5)[0].T + contour = np.array(smooth_contour(contour[0], contour[1], contour_smoothing)) + + # plot contour + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots(1,2,figsize=(10, 8)) + # ax[0].imshow(cc_mask_orig) + # ax[1].imshow(cc_mask) + # ax[0].plot(contour[1], contour[0], 'r-') + # ax[1].plot(contour[1], contour[0], 'r-') + # plt.show() + + return contour + +def get_endpoints(cc_mask, AC_2d, PC_2d, resolution, return_coordinates=True, contour_smoothing=5): """ Determines endpoints of CC by finding the point in the contour closest to the anterior and posterior commisure (with some offsets) @@ -39,6 +147,8 @@ def get_endpoints(cc_mask, AC_2d, PC_2d, resolution, return_coordinates=True, co theta_degrees = theta * 180 / np.pi rotated_cc_mask = scipy.ndimage.rotate(cc_mask, -theta_degrees, order=0, reshape=False) + contour = extract_cc_contour(rotated_cc_mask, contour_smoothing) + # rotate points around center origin_point = np.array([image_size[0] // 2, image_size[1] // 2]) @@ -51,13 +161,7 @@ def get_endpoints(cc_mask, AC_2d, PC_2d, resolution, return_coordinates=True, co rotated_PC_2d = (rot_matrix @ pc_centered) + origin_point rotated_AC_2d = (rot_matrix @ ac_centered) + origin_point - - # get contour of CC - gaussian_cc_mask = scipy.ndimage.gaussian_filter(rotated_cc_mask.astype(float), sigma=contour_smoothing) - # gaussian_cc_mask = scipy.ndimage.gaussian_filter(gaussian_cc_mask, sigma=1.0) - contour = skimage.measure.find_contours(gaussian_cc_mask, level=0.5)[0].T - - + # Add z=0 coordinate to make 3D, then remove it after resampling contour_3d = np.vstack([contour, np.zeros(contour.shape[1])]) diff --git a/CorpusCallosum/shape/cc_mesh.py b/CorpusCallosum/shape/cc_mesh.py index fd78ae2e..55dbc37b 100644 --- a/CorpusCallosum/shape/cc_mesh.py +++ b/CorpusCallosum/shape/cc_mesh.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys import tempfile from pathlib import Path @@ -28,6 +27,7 @@ from whippersnappy.core import snap1 import FastSurferCNN.utils.logging as logging +from CorpusCallosum.shape.cc_endpoint_heuristic import smooth_contour from CorpusCallosum.shape.cc_thickness import HiddenPrints, make_mesh_from_contour logger = logging.get_logger(__name__) @@ -802,23 +802,9 @@ def smooth_contour(self, contour_idx, window_size=5): """ x, y = self.contours[contour_idx].T - # Ensure the window size is odd - if window_size % 2 == 0: - window_size += 1 + x, y = smooth_contour(x, y, window_size) - # Create a padded version of the arrays to handle the edges - x_padded = np.pad(x, (window_size // 2, window_size // 2), mode="wrap") - y_padded = np.pad(y, (window_size // 2, window_size // 2), mode="wrap") - - # Apply moving average - x_smoothed = np.zeros_like(x) - y_smoothed = np.zeros_like(y) - - for i in range(len(x)): - x_smoothed[i] = np.mean(x_padded[i : i + window_size]) - y_smoothed[i] = np.mean(y_padded[i : i + window_size]) - - self.contours[contour_idx] = np.array([x_smoothed, y_smoothed]).T + self.contours[contour_idx] = np.array([x, y]).T def plot_cc_contour_with_levelsets(self, contour_idx=0, levelpaths=None, title=None, save_path=None, colorbar=True): """Plot a contour with levelset visualization. 
@@ -1348,7 +1334,9 @@ def to_fs_coordinates(self, vox_size: tuple[int, int, int], image_size: tuple[in self.v = self.v[:, [2, 0, 1]] # LIA to ALI? self.v *= (vox_size[0] **2) ## ??? self.v[:, 1] -= image_size[1] * vox_size[1] // 2 # move 0 to center of image - self.v[:, 2] += image_size[2] * vox_size[2] // 2 + self.v[:, 2] += image_size[2] * vox_size[2] // 2 + self.v[:, 0] += vox_size[0] / 2 + def write_fssurf(self, filename): diff --git a/CorpusCallosum/shape/cc_thickness.py b/CorpusCallosum/shape/cc_thickness.py index 05dcd55a..640db397 100644 --- a/CorpusCallosum/shape/cc_thickness.py +++ b/CorpusCallosum/shape/cc_thickness.py @@ -199,7 +199,7 @@ def cc_thickness( # plot mesh points with index next to point # import matplotlib.pyplot as plt # fig, ax = plt.subplots(figsize=(10, 8)) - # ax.scatter(mesh_points[:,0], mesh_points[:,1], label='Mesh Points') + # ax.plot(mesh_points[:,0], mesh_points[:,1], label='Mesh Points') # for i in range(len(mesh_points)): # ax.text(mesh_points[i,0], mesh_points[i,1], str(i), fontsize=7) # plt.show() From 5e6970471e659513c7d7cfece51ac06970442872 Mon Sep 17 00:00:00 2001 From: ClePol Date: Wed, 1 Oct 2025 14:34:20 +0200 Subject: [PATCH 19/68] fixed freesurfer surface conversion and scaling issues --- CorpusCallosum/cc_visualization.py | 6 +- CorpusCallosum/data/fsaverage_data.json | 142 +++++++++++------- CorpusCallosum/data/read_write.py | 3 +- CorpusCallosum/fastsurfer_cc.py | 9 +- CorpusCallosum/shape/cc_endpoint_heuristic.py | 125 ++++++++++----- CorpusCallosum/shape/cc_mesh.py | 35 ++++- CorpusCallosum/shape/cc_postprocessing.py | 19 +-- CorpusCallosum/shape/cc_thickness.py | 4 +- CorpusCallosum/visualization/visualization.py | 8 +- 9 files changed, 229 insertions(+), 122 deletions(-) diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index be447050..e6511c12 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -3,7 +3,9 @@ import numpy as np +from CorpusCallosum.data.constants import FSAVERAGE_DATA_PATH from CorpusCallosum.data.fsaverage_cc_template import load_fsaverage_cc_template +from CorpusCallosum.data.read_write import load_fsaverage_data from CorpusCallosum.shape.cc_mesh import CC_Mesh @@ -93,6 +95,8 @@ def main( # Load data and create mesh cc_mesh = CC_Mesh(num_slices=1) # Will be resized when loading data + _, _, vox2ras_tkr = load_fsaverage_data(FSAVERAGE_DATA_PATH) + if contours_path is not None: cc_mesh.load_contours(str(contours_path)) else: @@ -128,7 +132,7 @@ def main( contour_idx=len(cc_mesh.contours) // 2, save_path=str(output_dir / "midslice_2d.png") ) - cc_mesh.to_fs_coordinates(vox_size=[resolution, resolution, resolution]) + cc_mesh.to_fs_coordinates(vox_size=[resolution, resolution, resolution], vox2ras_tkr=vox2ras_tkr) cc_mesh.write_vtk(str(output_dir / "cc_mesh.vtk")) cc_mesh.write_fssurf(str(output_dir / "cc_mesh.fssurf")) cc_mesh.write_overlay(str(output_dir / "cc_mesh_overlay.curv")) diff --git a/CorpusCallosum/data/fsaverage_data.json b/CorpusCallosum/data/fsaverage_data.json index 9efa2336..0fdd17fb 100644 --- a/CorpusCallosum/data/fsaverage_data.json +++ b/CorpusCallosum/data/fsaverage_data.json @@ -1,62 +1,88 @@ { - "affine": [ - [ - -1.0, - 0.0, - 0.0, - 128.0 + "affine": [ + [ + -1.0, + 0.0, + 0.0, + 128.0 + ], + [ + 0.0, + 0.0, + 1.0, + -128.0 + ], + [ + 0.0, + -1.0, + 0.0, + 128.0 + ], + [ + 0.0, + 0.0, + 0.0, + 1.0 + ] ], - [ - 0.0, - 0.0, - 1.0, - -128.0 - ], - [ - 0.0, - -1.0, - 0.0, - 128.0 - ], - [ - 0.0, - 0.0, - 0.0, - 1.0 - ] - ], - 
"header": { - "dims": [ - 256, - 256, - 256 - ], - "delta": [ - 1.0, - 1.0, - 1.0 - ], - "Mdc": [ - [ - -1.0, - 0.0, - 0.0 - ], - [ - 0.0, - 0.0, - 10000000000.0 - ], - [ - 0.0, - -10000000000.0, - 0.0 - ] - ], - "Pxyz_c": [ - 128.0, - -128.0, - 128.0 + "header": { + "dims": [ + 256, + 256, + 256 + ], + "delta": [ + 1.0, + 1.0, + 1.0 + ], + "Mdc": [ + [ + -1.0, + 0.0, + 0.0 + ], + [ + 0.0, + 0.0, + 10000000000.0 + ], + [ + 0.0, + -10000000000.0, + 0.0 + ] + ], + "Pxyz_c": [ + 128.0, + -128.0, + 128.0 + ] + }, + "vox2ras_tkr": [ + [ + -1.0, + 0.0, + 0.0, + 128.0 + ], + [ + 0.0, + 0.0, + 1.0, + -128.0 + ], + [ + 0.0, + -1.0, + 0.0, + 128.0 + ], + [ + 0.0, + 0.0, + 0.0, + 1.0 + ] ] - } } \ No newline at end of file diff --git a/CorpusCallosum/data/read_write.py b/CorpusCallosum/data/read_write.py index 786b97ef..e4a86234 100644 --- a/CorpusCallosum/data/read_write.py +++ b/CorpusCallosum/data/read_write.py @@ -268,6 +268,7 @@ def load_fsaverage_data(data_path): # Convert lists back to numpy arrays affine_matrix = np.array(data["affine"]) + vox2ras_tkr = np.array(data["vox2ras_tkr"]) header_data = data["header"].copy() header_data["Mdc"] = np.array(header_data["Mdc"]) header_data["Pxyz_c"] = np.array(header_data["Pxyz_c"]) @@ -276,4 +277,4 @@ def load_fsaverage_data(data_path): if affine_matrix.shape != (4, 4): raise ValueError(f"Expected 4x4 affine matrix, got shape {affine_matrix.shape}") - return affine_matrix, header_data \ No newline at end of file + return affine_matrix, header_data, vox2ras_tkr \ No newline at end of file diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index 2a3ab95c..70fbc00d 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -278,7 +278,7 @@ def centroid_registration(aseg_nib, verbose=False): # Load pre-computed fsaverage centroids and data from static files centroids_dst = load_fsaverage_centroids(FSAVERAGE_CENTROIDS_PATH) - fsaverage_affine, fsaverage_header = load_fsaverage_data(FSAVERAGE_DATA_PATH) + fsaverage_affine, fsaverage_header, vox2ras_tkr = load_fsaverage_data(FSAVERAGE_DATA_PATH) centroids_mov, ids_not_found = get_centroids_from_nib(aseg_nib, label_ids=list(centroids_dst.keys())) @@ -303,7 +303,7 @@ def centroid_registration(aseg_nib, verbose=False): ) fsaverage_hires_affine = resolution_trans @ fsaverage_affine - return orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, fsaverage_hires_affine, fsaverage_header + return orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, fsaverage_hires_affine, fsaverage_header, vox2ras_tkr def localize_ac_pc(midslices, aseg_nib, orig_fsaverage_vox2vox, model_localization, slices_to_analyze): @@ -571,7 +571,8 @@ def main( aseg_nib = nib.load(aseg_path) logger.info("Performing centroid registration to fsaverage space") - orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, fsaverage_hires_affine, fsaverage_header = centroid_registration( + (orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, + fsaverage_hires_affine, fsaverage_header, fsaverage_vox2ras_tkr) = centroid_registration( aseg_nib, verbose=False ) @@ -651,7 +652,7 @@ def main( vtk_file_path=vtk_file_path, thickness_image_path=thickness_image_path, vox_size=orig.header.get_zooms(), - image_size=orig.shape, + vox2ras_tkr=fsaverage_vox2ras_tkr, verbose=verbose, save_template=save_template, ) diff --git a/CorpusCallosum/shape/cc_endpoint_heuristic.py b/CorpusCallosum/shape/cc_endpoint_heuristic.py index 83df78c0..8bf9b8c6 100644 --- a/CorpusCallosum/shape/cc_endpoint_heuristic.py +++ 
b/CorpusCallosum/shape/cc_endpoint_heuristic.py
@@ -13,16 +13,39 @@
 # limitations under the License.
 
 import lapy
 import numpy as np
+import numpy.typing as npt
 import scipy.ndimage
 import skimage.measure
 from scipy.ndimage import label
 
 
-def smooth_contour(x, y, window_size):
-    """
-    Smooth a contour using a moving average filter.
+def smooth_contour(x: npt.NDArray[np.float64],
+                   y: npt.NDArray[np.float64],
+                   window_size: int) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
+    """Smooth a contour using a moving average filter.
+
+    Parameters
+    ----------
+    x : npt.NDArray[np.float64]
+        x-coordinates of the contour points
+    y : npt.NDArray[np.float64]
+        y-coordinates of the contour points
+    window_size : int
+        Size of the smoothing window. Must be at least 2; even values are
+        incremented to the next odd number.
+
+    Returns
+    -------
+    tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]
+        Smoothed x and y coordinates of the contour
     """
+    # Ensure window_size is an integer
+    window_size = int(window_size)
+
+    if window_size < 2:
+        raise ValueError(f"Smoothing window size of {window_size} is too small")
 
     # Ensure the window size is odd
     if window_size % 2 == 0:
@@ -47,9 +70,17 @@ def smooth_contour(x, y, window_size):
     return x_smoothed, y_smoothed
 
 
-def connect_diagonally_connected_components(cc_mask):
-    """
-    Connects diagonally connected components in the CC mask.
+def connect_diagonally_connected_components(cc_mask: npt.NDArray[np.bool_]) -> None:
+    """Connect diagonally connected components in the CC mask.
+
+    Parameters
+    ----------
+    cc_mask : npt.NDArray[np.bool_]
+        Binary mask of the corpus callosum
+
+    Notes
+    -----
+    Modifies the input mask in-place to connect diagonally connected components.
     """
 
     # Create padded mask to handle boundary conditions
@@ -103,9 +134,21 @@ def connect_diagonally_connected_components(cc_mask):
     cc_mask[connects_diagonals] = 1
 
 
-def extract_cc_contour(cc_mask, contour_smoothing=5):
-    """
-    Extracts the contour of the CC from the mask.
+def extract_cc_contour(cc_mask: npt.NDArray[np.bool_],
+                       contour_smoothing: int = 5) -> npt.NDArray[np.float64]:
+    """Extract the contour of the CC from the mask.
+
+    Parameters
+    ----------
+    cc_mask : npt.NDArray[np.bool_]
+        Binary mask of the corpus callosum
+    contour_smoothing : int, optional
+        Window size for contour smoothing, by default 5
+
+    Returns
+    -------
+    npt.NDArray[np.float64]
+        Array of shape (2, N) containing x,y coordinates of the contour points
     """
     # cc_mask_orig = cc_mask
     cc_mask = cc_mask.copy()
@@ -126,12 +169,42 @@ def extract_cc_contour(cc_mask, contour_smoothing=5):
     return contour
 
 
-def get_endpoints(cc_mask, AC_2d, PC_2d, resolution, return_coordinates=True, contour_smoothing=5):
-    """
-    Determines endpoints of CC by finding the point in the contour closest to
-    the anterior and posterior commisure (with some offsets)
-
-    NOTE: Expects LIA orientation
+def get_endpoints(cc_mask: npt.NDArray[np.bool_],
+                  AC_2d: npt.NDArray[np.float64],
+                  PC_2d: npt.NDArray[np.float64],
+                  resolution: float,
+                  return_coordinates: bool = True,
+                  contour_smoothing: int = 5) -> (
+        tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]] |
+        tuple[npt.NDArray[np.float64], int, int]):
+    """Determine endpoints of CC by finding points closest to AC and PC.
+ + Parameters + ---------- + cc_mask : npt.NDArray[np.bool_] + Binary mask of the corpus callosum + AC_2d : npt.NDArray[np.float64] + 2D coordinates of the anterior commissure + PC_2d : npt.NDArray[np.float64] + 2D coordinates of the posterior commissure + resolution : float + Image resolution in mm + return_coordinates : bool, optional + If True, return endpoint coordinates, otherwise return indices, by default True + contour_smoothing : int, optional + Window size for contour smoothing, by default 5 + + Returns + ------- + tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]] | tuple[npt.NDArray[np.float64], int, int] + If return_coordinates is True: + (contour, anterior_point, posterior_point) + If return_coordinates is False: + (contour, anterior_index, posterior_index) + + Notes + ----- + Expects LIA orientation """ image_size = cc_mask.shape @@ -215,25 +288,3 @@ def get_endpoints(cc_mask, AC_2d, PC_2d, resolution, return_coordinates=True, co return contour_rotated, start_point_A, start_point_P else: return contour_rotated, AC_startpoint_idx, PC_startpoint_idx - - -def get_endpoints_from_nib(cc_label_nib, paths_csv, subj_id, return_coordinates=True): - cc_mask = cc_label_nib.get_fdata() == 192 - cc_mask = cc_mask[cc_mask.shape[0] // 2] - - posterior_commisure_center = paths_csv.loc[subj_id, "PC_center_r":"PC_center_s"].to_numpy().astype(float) - anterior_commisure_center = paths_csv.loc[subj_id, "AC_center_r":"AC_center_s"].to_numpy().astype(float) - - # adjust LR from label coordinates to orig_up coordinates - posterior_commisure_center[0] = 128 - anterior_commisure_center[0] = 128 - - # orientation I, A - # rotate image so anterior and posterior commisure are horizontal - AC_2d = anterior_commisure_center[1:] - PC_2d = posterior_commisure_center[1:] - - return get_endpoints( - cc_mask, AC_2d, PC_2d, resolution=cc_label_nib.header.get_zooms()[1], return_coordinates=return_coordinates - ) - diff --git a/CorpusCallosum/shape/cc_mesh.py b/CorpusCallosum/shape/cc_mesh.py index 55dbc37b..41fb9149 100644 --- a/CorpusCallosum/shape/cc_mesh.py +++ b/CorpusCallosum/shape/cc_mesh.py @@ -1325,17 +1325,40 @@ def __make_parent_folder(filename: str): output_folder = Path(filename).parent output_folder.mkdir(parents=False, exist_ok=True) - def to_fs_coordinates(self, vox_size: tuple[int, int, int], image_size: tuple[int, int, int]): + def to_fs_coordinates(self, vox2ras_tkr: np.ndarray, vox_size: tuple[int, int, int]): """Convert mesh coordinates to FreeSurfer coordinate system. Transforms the mesh vertices from the original coordinate system to the FreeSurfer coordinate system by reordering axes and applying appropriate offsets. """ - self.v = self.v[:, [2, 0, 1]] # LIA to ALI? - self.v *= (vox_size[0] **2) ## ??? 
-        self.v[:, 1] -= image_size[1] * vox_size[1] // 2 # move 0 to center of image
-        self.v[:, 2] += image_size[2] * vox_size[2] // 2
-        self.v[:, 0] += vox_size[0] / 2
+
+        # to voxel coordinates
+        v_vox = self.v.copy()
+
+        # to LSA axis order
+        v_vox = v_vox[:, [2, 1, 0]]
+        # to voxel units
+        v_vox /= vox_size[0]
+        # center LR
+        v_vox[:, 0] += 256 // 2
+        # flip SI
+        v_vox[:, 1] = -v_vox[:, 1]
+
+        # tkrRAS = Torig*[C R S 1]'
+        # Torig: mri_info --vox2ras-tkr orig.mgz
+        # https://surfer.nmr.mgh.harvard.edu/fswiki/CoordinateSystems
+        self.v = (vox2ras_tkr @ np.concatenate([v_vox, np.ones((self.v.shape[0], 1))], axis=1).T).T[:, :3]
+        self.v = self.v * vox_size[0]

diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/cc_postprocessing.py
index 99e26350..b7aedb9d 100644
--- a/CorpusCallosum/shape/cc_postprocessing.py
+++ b/CorpusCallosum/shape/cc_postprocessing.py
@@ -104,7 +104,7 @@ def create_slice_affine(temp_seg_affine, slice_idx, fsaverage_middle):
 
 def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thickness_points, subdivisions,
-                  subdivision_method, contour_smoothing):
+                  subdivision_method, contour_smoothing, vox_size):
     """Process a single slice for corpus callosum measurements.
 
     Performs detailed analysis of a corpus callosum slice, including:
@@ -164,7 +164,7 @@ def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thi
     contour, anterior_endpoint_idx, posterior_endpoint_idx = get_endpoints(cc_mask_slice,
                                                                            ac_coords, pc_coords,
-                                                                           affine.diagonal()[1],
+                                                                           vox_size,
                                                                            return_coordinates=False,
                                                                            contour_smoothing=contour_smoothing)
     contour_1mm = convert_to_ras(contour, affine)
@@ -194,7 +194,6 @@ def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thi
                                           contour_1mm[:,anterior_endpoint_idx],
                                           contour_1mm[:,posterior_endpoint_idx])[0]
                             for split_contour in split_contours]
-        split_contours_hofer_frahm = None
 
     elif subdivision_method == "vertical":
         areas, split_contours = subdivide_contour(contour_acpc, subdivisions, plot=False)
@@ -245,7 +244,7 @@ def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thi
 def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac_coords, pc_coords,
                    num_thickness_points, subdivisions, subdivision_method, contour_smoothing,
                    debug_image_path=None, one_debug_image=False,
-                   thickness_image_path=None, vox_size=None, image_size=None,
+                   thickness_image_path=None, vox_size=None, vox2ras_tkr=None,
                    save_template=None, surf_file_path=None, overlay_file_path=None,
                    cc_html_path=None, vtk_file_path=None, verbose=False):
     """Process corpus callosum slices based on selection mode.
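The `vox2ras_tkr` argument threaded through `process_slices` above is FreeSurfer's tkr-space transform (`mri_info --vox2ras-tkr orig.mgz`), which `CC_Mesh.to_fs_coordinates` applies to homogeneous voxel coordinates. A minimal standalone sketch of that mapping, assuming an `(N, 3)` coordinate array (the helper name `vox_to_tkr_ras` is illustrative, not part of the module):

```python
import numpy as np

def vox_to_tkr_ras(v_vox: np.ndarray, vox2ras_tkr: np.ndarray) -> np.ndarray:
    """Map (N, 3) voxel coordinates to tkr-RAS via homogeneous coordinates."""
    # Append a column of ones: (N, 3) -> (N, 4)
    homog = np.concatenate([v_vox, np.ones((v_vox.shape[0], 1))], axis=1)
    # Apply the 4x4 transform and drop the homogeneous coordinate
    return (vox2ras_tkr @ homog.T).T[:, :3]
```

The matrix itself can be read from a volume header via `nib.load(path).header.get_vox2ras_tkr()`, as done for the fsaverage data later in this series.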
@@ -280,7 +279,7 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac if slice_selection == "middle": cc_mesh = CC_Mesh(num_slices=1) cc_mesh.set_acpc_coords(ac_coords, pc_coords) - cc_mesh.set_resolution(vox_size[0]) # contour is always scaled to 1 mm + cc_mesh.set_resolution(vox_size[0]) # Process only the middle slice slice_idx = segmentation.shape[0] // 2 @@ -295,7 +294,8 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac num_thickness_points, subdivisions, subdivision_method, - contour_smoothing) + contour_smoothing, + vox_size[0]) cc_mesh.add_contour(0, contour_with_thickness[0], @@ -313,7 +313,7 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac num_slices = segmentation.shape[0] cc_mesh = CC_Mesh(num_slices=num_slices) cc_mesh.set_acpc_coords(ac_coords, pc_coords) - cc_mesh.set_resolution(vox_size[0]) # contour is always scaled to 1 mm + cc_mesh.set_resolution(vox_size[0]) # Process multiple slices or specific slice if slice_selection == "all": @@ -337,7 +337,8 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac ac_coords, pc_coords, slice_affine, num_thickness_points, subdivisions, subdivision_method, - contour_smoothing) + contour_smoothing, + vox_size[0]) # insert cc_mesh.add_contour(slice_idx, @@ -393,7 +394,7 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac #cc_mesh.write_vtk(str(output_dir / 'cc_mesh.vtk')) - cc_mesh.to_fs_coordinates(vox_size=vox_size, image_size=image_size) + cc_mesh.to_fs_coordinates(vox2ras_tkr=vox2ras_tkr, vox_size=vox_size) if overlay_file_path is not None: if verbose: diff --git a/CorpusCallosum/shape/cc_thickness.py b/CorpusCallosum/shape/cc_thickness.py index 640db397..478eae16 100644 --- a/CorpusCallosum/shape/cc_thickness.py +++ b/CorpusCallosum/shape/cc_thickness.py @@ -63,8 +63,8 @@ def convert_to_ras(contour, vox2ras_matrix, get_parameters=False): # get scaling by getting length of three column vectors scaling = np.linalg.norm(vox2ras_matrix[:3, :3], axis=0) - # apply transformation - contour = (contour.T / scaling[1:]).T + # voxel * vox_size = mm + contour = (contour.T * scaling[1:]).T if get_parameters: return contour, anterior_reversed, superior_reversed, swap_axes diff --git a/CorpusCallosum/visualization/visualization.py b/CorpusCallosum/visualization/visualization.py index bcb3b8af..3574045b 100644 --- a/CorpusCallosum/visualization/visualization.py +++ b/CorpusCallosum/visualization/visualization.py @@ -125,15 +125,15 @@ def plot_contours( # scale contour data by vox_size split_contours = ( - [split_contour * vox_size for split_contour in split_contours] if split_contours is not None else None + [split_contour / vox_size for split_contour in split_contours] if split_contours is not None else None ) split_contours_hofer_frahm = ( - [split_contour * vox_size for split_contour in split_contours_hofer_frahm] + [split_contour / vox_size for split_contour in split_contours_hofer_frahm] if split_contours_hofer_frahm is not None else None ) - midline_equidistant = midline_equidistant * vox_size - levelpaths = [levelpath * vox_size for levelpath in levelpaths] + midline_equidistant = midline_equidistant / vox_size + levelpaths = [levelpath / vox_size for levelpath in levelpaths] NO_PLOTS = 1 if split_contours is not None: From 763409f0c59b3fffcea05eac5ae9ceb5eaf001e2 Mon Sep 17 00:00:00 2001 From: ClePol Date: Thu, 2 Oct 2025 14:00:51 +0200 Subject: [PATCH 20/68] docstrings, typehints 
and small bugfixes --- CorpusCallosum/cc_visualization.py | 50 +- CorpusCallosum/data/fsaverage_cc_template.py | 117 +-- .../data/generate_fsaverage_centroids.py | 28 +- CorpusCallosum/data/read_write.py | 227 +++--- CorpusCallosum/fastsurfer_cc.py | 293 ++++--- .../localization/localization_inference.py | 145 +++- CorpusCallosum/paint_cc_into_pred.py | 26 +- .../registration/mapping_helpers.py | 285 ++++--- .../segmentation/segmentation_inference.py | 191 +++-- .../segmentation_postprocessing.py | 296 ++++--- CorpusCallosum/shape/cc_endpoint_heuristic.py | 80 +- CorpusCallosum/shape/cc_mesh.py | 746 ++++++++++++------ CorpusCallosum/shape/cc_metrics.py | 37 +- CorpusCallosum/shape/cc_postprocessing.py | 435 ++++++---- CorpusCallosum/shape/cc_subsegment_contour.py | 174 +++- CorpusCallosum/shape/cc_thickness.py | 186 ++++- .../transforms/localization_transforms.py | 61 +- .../transforms/segmentation_transforms.py | 105 ++- CorpusCallosum/utils/utils.py | 38 +- CorpusCallosum/visualization/visualization.py | 182 +++-- 20 files changed, 2514 insertions(+), 1188 deletions(-) diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index e6511c12..8461616d 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -23,7 +23,7 @@ def options_parse() -> argparse.Namespace: parser.add_argument("--output_dir", type=str, required=True, help="Directory for output files") parser.add_argument("--resolution", type=float, default=1.0, help="Resolution in mm for the mesh") parser.add_argument( - "--smooth_iterations", type=int, default=1, help="Number of smoothing iterations to apply to the mesh" + "--smoothing_window", type=int, default=5, help="Window size for smoothing the contour" ) parser.add_argument( "--colormap", @@ -57,7 +57,7 @@ def main( measurement_points_path: str | Path, output_dir: str | Path, resolution: float = 1.0, - smooth_iterations: int = 1, + smoothing_window: int = 5, colormap: str = "red_to_yellow", color_range: tuple[float, float] | None = None, legend: str | None = None, @@ -65,26 +65,36 @@ def main( ) -> None: """Main function to visualize corpus callosum from template files. - This function: - 1. Loads contours and thickness values from template files - 2. Creates a CC_Mesh object - 3. Generates and saves visualizations - - Args: - contours_path: Path to contours.txt file - thickness_path: Path to thickness_values.txt file - measurement_points_path: Path to file containing the original vertex indices where thickness was measured - output_dir: Directory for output files - resolution: Resolution in mm for the mesh - smooth_iterations: Number of smoothing iterations to apply to the mesh - colormap: Which colormap to use. Options are: + This function loads contours and thickness values from template files, + creates a CC_Mesh object, and generates visualizations. + + Parameters + ---------- + contours_path : str or Path or None + Path to contours.txt file. + thickness_path : str or Path + Path to thickness_values.txt file. + measurement_points_path : str or Path + Path to file containing original vertex indices where thickness was measured. + output_dir : str or Path + Directory for output files. + resolution : float, optional + Resolution in mm for the mesh, by default 1.0. + smoothing_window : int, optional + Window size for smoothing the contour, by default 5. + colormap : str, optional + Colormap to use for visualization, by default "red_to_yellow". 
+ Options: - "red_to_blue": Red -> Orange -> Grey -> Light Blue -> Blue - "blue_to_red": Blue -> Light Blue -> Grey -> Orange -> Red - "red_to_yellow": Red -> Yellow -> Light Blue -> Blue - "yellow_to_red": Yellow -> Light Blue -> Blue -> Red - - color_range: Optional tuple of (min, max) to set fixed color range for the colorbar - twoD: If True, generate 2D visualization instead of 3D mesh + color_range : tuple[float, float], optional + Fixed range (min, max) for the colorbar, by default None. + legend : str, optional + Legend for the colorbar, by default None. + twoD : bool, optional + If True, generate 2D visualization instead of 3D mesh, by default False. """ # Convert paths to Path objects contours_path = Path(contours_path) if contours_path is not None else None @@ -115,7 +125,7 @@ def main( else: cc_mesh.fill_thickness_values() # Create and process mesh - cc_mesh.create_mesh(smooth=smooth_iterations, closed=False) + cc_mesh.create_mesh(smooth=smoothing_window, closed=False) # Generate visualizations cc_mesh.plot_mesh( @@ -147,7 +157,7 @@ def main( "measurement_points_path": options.measurement_points, "output_dir": options.output_dir, "resolution": options.resolution, - "smooth_iterations": options.smooth_iterations, + "smoothing_window": options.smoothing_window, "colormap": options.colormap, "color_range": options.color_range, "legend": options.legend, diff --git a/CorpusCallosum/data/fsaverage_cc_template.py b/CorpusCallosum/data/fsaverage_cc_template.py index 196873ed..3e0cd854 100644 --- a/CorpusCallosum/data/fsaverage_cc_template.py +++ b/CorpusCallosum/data/fsaverage_cc_template.py @@ -22,43 +22,65 @@ from CorpusCallosum.shape.cc_postprocessing import process_slice -def smooth_contour(contour, window_size=5): - """ - Smooth a contour using a moving average filter. - - Parameters - ---------- - contour : tuple of arrays - The contour coordinates (x, y). - window_size : int - Size of the smoothing window. - - Returns - ------- - tuple of arrays - The smoothed contour coordinates (x, y). - """ - x, y = contour - - # Ensure the window size is odd - if window_size % 2 == 0: - window_size += 1 - - # Create a padded version of the arrays to handle the edges - x_padded = np.pad(x, (window_size//2, window_size//2), mode='wrap') - y_padded = np.pad(y, (window_size//2, window_size//2), mode='wrap') - - # Apply moving average - x_smoothed = np.zeros_like(x) - y_smoothed = np.zeros_like(y) - - for i in range(len(x)): - x_smoothed[i] = np.mean(x_padded[i:i+window_size]) - y_smoothed[i] = np.mean(y_padded[i:i+window_size]) - - return (x_smoothed, y_smoothed) - -def load_fsaverage_cc_template(): +def smooth_contour(contour: tuple[np.ndarray, np.ndarray], window_size: int = 5) -> tuple[np.ndarray, np.ndarray]: + """Smooth a contour using a moving average filter. + + Parameters + ---------- + contour : tuple of arrays + The contour coordinates (x, y). + window_size : int + Size of the smoothing window. + + Returns + ------- + tuple of arrays + The smoothed contour coordinates (x, y). 
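+
+    Examples
+    --------
+    Illustrative call on a closed contour::
+
+        theta = np.linspace(0, 2 * np.pi, 50)
+        x_s, y_s = smooth_contour((np.cos(theta), np.sin(theta)), window_size=5)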
+ + """ + x, y = contour + + # Ensure the window size is odd + if window_size % 2 == 0: + window_size += 1 + + # Create a padded version of the arrays to handle the edges + x_padded = np.pad(x, (window_size//2, window_size//2), mode='wrap') + y_padded = np.pad(y, (window_size//2, window_size//2), mode='wrap') + + # Apply moving average + x_smoothed = np.zeros_like(x) + y_smoothed = np.zeros_like(y) + + for i in range(len(x)): + x_smoothed[i] = np.mean(x_padded[i:i+window_size]) + y_smoothed[i] = np.mean(y_padded[i:i+window_size]) + + return (x_smoothed, y_smoothed) + + +def load_fsaverage_cc_template() -> tuple[ + np.ndarray, tuple[np.ndarray, np.ndarray], np.ndarray, np.ndarray, np.ndarray, tuple[int, int] +]: + """Load and process the fsaverage corpus callosum template. + + This function loads the fsaverage segmentation from FreeSurfer's data directory, + extracts the corpus callosum mask, and processes it to create a smooth template. + + Returns + ------- + tuple + Contains: + - contour : tuple[np.ndarray, np.ndarray] : x and y coordinates of the contour points. + - anterior_endpoint_idx : np.ndarray : Index of the anterior endpoint. + - posterior_endpoint_idx : np.ndarray : Index of the posterior endpoint. + + Raises + ------ + OSError + If FREESURFER_HOME environment variable is not set correctly. + + """ # smooth outside contour # Apply smoothing to the outside contour using a moving average @@ -101,16 +123,17 @@ def load_fsaverage_cc_template(): cc_mask = cc_mask_smoothed.astype(int) cc_mask[cc_mask > 0] = 192 - (_, contour_with_thickness, anterior_endpoint_idx, - posterior_endpoint_idx) = process_slice(segmentation=cc_mask[None], - slice_idx=0, - ac_coords=AC, - pc_coords=PC, - affine=fsaverage_seg.affine, - num_thickness_points=100, - subdivisions=[1/6, 1/2, 2/3, 3/4], - subdivision_method="shape", - contour_smoothing=1.0) + (_, contour_with_thickness, anterior_endpoint_idx, + posterior_endpoint_idx) = process_slice(segmentation=cc_mask[None], + slice_idx=0, + ac_coords=AC, + pc_coords=PC, + affine=fsaverage_seg.affine, + num_thickness_points=100, + subdivisions=[1/6, 1/2, 2/3, 3/4], + subdivision_method="shape", + contour_smoothing=5, + vox_size=1) outside_contour = contour_with_thickness[0].T diff --git a/CorpusCallosum/data/generate_fsaverage_centroids.py b/CorpusCallosum/data/generate_fsaverage_centroids.py index 443a7fb3..9b69beb5 100644 --- a/CorpusCallosum/data/generate_fsaverage_centroids.py +++ b/CorpusCallosum/data/generate_fsaverage_centroids.py @@ -33,8 +33,31 @@ logger = logging.get_logger(__name__) -def main(): - """Generate and save fsaverage centroids to a static file.""" +def main() -> None: + """Generate and save fsaverage centroids to a static file. + + This script extracts centroids from the fsaverage template segmentation + and saves them to a JSON file for fast loading during pipeline execution. + + The script performs the following steps: + 1. Load fsaverage segmentation from FreeSurfer directory + 2. Extract centroids for all anatomical structures + 3. Save centroids to JSON file + 4. 
Extract and save affine matrix and header fields + + Raises + ------ + OSError + If FREESURFER_HOME environment variable is not set or invalid + FileNotFoundError + If required fsaverage files are not found + + Notes + ----- + The script saves two files: + - fsaverage_centroids.json : Contains centroids for each anatomical structure + - fsaverage_data.json : Contains affine matrix and header information + """ # Get fsaverage path from FreeSurfer environment try: @@ -97,6 +120,7 @@ def main(): # Combine affine and header data combined_data = { "affine": fsaverage_affine.tolist(), # Convert numpy array to list for JSON serialization + "vox2ras_tkr": fsaverage_nib.header.get_vox2ras_tkr().tolist(), "header": { "dims": dims, "delta": delta, diff --git a/CorpusCallosum/data/read_write.py b/CorpusCallosum/data/read_write.py index e4a86234..52df104d 100644 --- a/CorpusCallosum/data/read_write.py +++ b/CorpusCallosum/data/read_write.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import multiprocessing +from pathlib import Path import nibabel as nib import numpy as np @@ -22,22 +24,24 @@ logger = logging.get_logger(__name__) -def run_in_background(function, debug=False, *args, **kwargs): +def run_in_background(function: callable, debug: bool = False, *args, **kwargs) -> multiprocessing.Process | None: """Run a function in the background using multiprocessing. - This function executes the given function either in a separate process (normal mode) - or in the current process (debug mode). In debug mode, the function is executed - synchronously for easier debugging. + Parameters + ---------- + function : callable + The function to execute. + debug : bool, optional + If True, run synchronously in current process, by default False. + *args + Positional arguments to pass to the function. + **kwargs + Keyword arguments to pass to the function. - Args: - function: The function to execute - debug (bool): If True, run synchronously in current process - args: Positional arguments to pass to the function - kwargs: Keyword arguments to pass to the function - - Returns: - multiprocessing.Process or None: Process object if running in background, - None if in debug mode + Returns + ------- + multiprocessing.Process or None + Process object if running in background, None if in debug mode. """ if debug: function(*args, **kwargs) @@ -51,25 +55,21 @@ def run_in_background(function, debug=False, *args, **kwargs): def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: list[int] | None = None) -> dict[int, np.ndarray]: """Get centroids of segmentation labels in RAS coordinates. - - Calculates the centroid coordinates for each segmentation label in the image. - If label_ids is provided, only calculates centroids for those specific labels. - Coordinates are returned in RAS (Right-Anterior-Superior) coordinate system. - - Args: - seg_img (nib.Nifti1Image) - Nibabel image containing segmentation labels - label_ids (list[int] | None) - Optional list of specific label IDs to process. - If None, processes all non-zero labels. - - Returns: - centroids (dict | dict, list) - If label_ids is None, returns a dict mapping label IDs to their centroids (x,y,z) in RAS coordinates. 
- If label_ids is provided, returns a tuple containing: - - dict[int, np.ndarray]: Mapping of found label IDs to their centroids - - list[int]: List of label IDs that were not found in the image + Parameters + ---------- + seg_img : nibabel.Nifti1Image + Input segmentation image. + label_ids : list[int], optional + List of label IDs to extract centroids for. If None, extracts all non-zero labels. + + Returns + ------- + dict[int, np.ndarray] + If label_ids is None, returns a dict mapping label IDs to their centroids (x,y,z) in RAS coordinates. + If label_ids is provided, returns a tuple containing: + - dict[int, np.ndarray]: Mapping of found label IDs to their centroids. + - list[int]: List of label IDs that were not found in the image. """ # Get segmentation data and affine seg_data = seg_img.get_fdata() @@ -109,40 +109,48 @@ def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: list[int] | None -def save_nifti_background(io_processes, data, affine, header, filepath): +def save_nifti_background( + io_processes: list, + data: np.ndarray, + affine: np.ndarray, + header: nib.Nifti1Header, + filepath: str | Path +) -> None: """Save a NIfTI image in a background process. - + Creates a MGHImage from the provided data and metadata, then saves it to disk using a background process to avoid blocking the main execution. - - Args: - io_processes (list): List to store background process handles - data (np.ndarray): Image data array - affine (np.ndarray): 4x4 affine transformation matrix - header: NIfTI header object containing metadata - filepath (str): Path where the image should be saved + + Parameters + ---------- + io_processes : list + List to store background process handles. + data : np.ndarray + Image data array. + affine : np.ndarray + 4x4 affine transformation matrix. + header : nib.Nifti1Header + NIfTI header object containing metadata. + filepath : str or Path + Path where the image should be saved. """ logger.info(f"Saving NIfTI image to {filepath}") io_processes.append(run_in_background(nib.save, False, nib.MGHImage(data, affine, header), filepath)) -def convert_numpy_to_json_serializable(obj): - """Convert numpy arrays in nested data structures to JSON serializable format. - - Recursively traverses dictionaries, lists, and numpy arrays, converting numpy arrays - to Python lists and numpy scalars to Python scalars for JSON serialization. - - Args: - obj: Any Python object that may contain numpy arrays (dict, list, np.ndarray, or scalar) - - Returns: - The input object with all numpy arrays converted to lists and numpy scalars to Python scalars - - Example: - >>> data = {'array': np.array([1, 2, 3]), 'nested': {'array': np.array([4, 5])}} - >>> result = convert_numpy_to_json_serializable(data) - >>> # Result: {'array': [1, 2, 3], 'nested': {'array': [4, 5]}} +def convert_numpy_to_json_serializable(obj: object) -> object: + """Convert numpy types to JSON serializable types. + + Parameters + ---------- + obj : object + Object to convert to JSON serializable type. + + Returns + ------- + object + JSON serializable version of the input object. """ if isinstance(obj, dict): return {k: convert_numpy_to_json_serializable(v) for k, v in obj.items()} @@ -157,24 +165,19 @@ def convert_numpy_to_json_serializable(obj): return obj -def load_fsaverage_centroids(centroids_path): +def load_fsaverage_centroids(centroids_path: str | Path) -> dict[int, np.ndarray]: """Load fsaverage centroids from static JSON file. 
- - Loads pre-computed centroids from a static JSON file, avoiding the need to - compute them from the fsaverage segmentation at runtime. - - Args: - centroids_path (str or Path): Path to the JSON file containing centroids - - Returns: - dict[int, np.ndarray]: Mapping of label IDs to their centroids (x,y,z) in RAS coordinates - - Raises: - FileNotFoundError: If the centroids file doesn't exist - json.JSONDecodeError: If the file is not valid JSON + + Parameters + ---------- + centroids_path : str or Path + Path to the JSON file containing centroids. + + Returns + ------- + dict[int, np.ndarray] + Dictionary mapping label IDs to their centroids in RAS coordinates. """ - import json - from pathlib import Path centroids_path = Path(centroids_path) if not centroids_path.exists(): @@ -192,23 +195,19 @@ def load_fsaverage_centroids(centroids_path): return centroids -def load_fsaverage_affine(affine_path): +def load_fsaverage_affine(affine_path: str | Path) -> np.ndarray: """Load fsaverage affine matrix from static text file. - - Loads pre-computed affine matrix from a static text file, avoiding the need to - load the fsaverage segmentation at runtime. - - Args: - affine_path (str or Path): Path to the text file containing affine matrix - - Returns: - np.ndarray: 4x4 affine transformation matrix - - Raises: - FileNotFoundError: If the affine file doesn't exist - ValueError: If the file doesn't contain a valid 4x4 matrix + + Parameters + ---------- + affine_path : str or Path + Path to the text file containing affine matrix. + + Returns + ------- + np.ndarray + 4x4 affine transformation matrix. """ - from pathlib import Path affine_path = Path(affine_path) if not affine_path.exists(): @@ -222,31 +221,41 @@ def load_fsaverage_affine(affine_path): return affine_matrix -def load_fsaverage_data(data_path): +def load_fsaverage_data(data_path: str | Path) -> tuple[np.ndarray, dict, np.ndarray]: """Load fsaverage affine matrix and header fields from static JSON file. - - Loads pre-computed affine matrix and header fields from a static JSON file, - avoiding the need to load the fsaverage segmentation at runtime. - - Args: - data_path (str or Path): Path to the JSON file containing combined data - - Returns: - tuple: Contains: - - affine_matrix (np.ndarray): 4x4 affine transformation matrix - - header_fields (dict): Header fields needed for LTA: - - dims (list[int]): Volume dimensions [x,y,z] - - delta (list[float]): Voxel size in mm [x,y,z] - - Mdc (np.ndarray): 3x3 direction cosines matrix - - Pxyz_c (np.ndarray): RAS center coordinates [x,y,z] + + Parameters + ---------- + data_path : str or Path + Path to the JSON file containing combined data. + + Returns + ------- + affine_matrix : np.ndarray + 4x4 affine transformation matrix. + header_fields : dict + Header fields needed for LTA: + - dims : list[int] + Volume dimensions [x,y,z]. + - delta : list[float] + Voxel size in mm [x,y,z]. + - Mdc : np.ndarray + 3x3 direction cosines matrix. + - Pxyz_c : np.ndarray + RAS center coordinates [x,y,z]. + vox2ras_tkr : np.ndarray + Voxel to RAS tkr-space transformation matrix. + + Raises + ------ + FileNotFoundError + If the data file doesn't exist. + json.JSONDecodeError + If the file is not valid JSON. + ValueError + If required fields are missing. 
- Raises: - FileNotFoundError: If the data file doesn't exist - json.JSONDecodeError: If the file is not valid JSON - ValueError: If required fields are missing """ - import json - from pathlib import Path data_path = Path(data_path) if not data_path.exists(): diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index 70fbc00d..e0354cb0 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -26,8 +26,8 @@ ) from CorpusCallosum.localization import localization_inference from CorpusCallosum.registration.mapping_helpers import ( - apply_transform_and_map_volume, apply_transform_to_pt, + apply_transform_to_volume, get_mapping_to_standard_space, interpolate_midplane, map_softlabels_to_orig, @@ -255,23 +255,39 @@ def options_parse() -> argparse.Namespace: return args -def centroid_registration(aseg_nib, verbose=False): +def centroid_registration(aseg_nib: nib.Nifti1Image, verbose: bool = False) -> tuple[ + np.ndarray, np.ndarray, np.ndarray, nib.Nifti1Header, np.ndarray +]: """Perform centroid-based registration between subject and fsaverage space. Computes a rigid transformation between the subject's segmentation and fsaverage space by aligning centroids of corresponding anatomical structures. - Args: - aseg_nib (nib.Nifti1Image): Subject's segmentation image - verbose (bool): Whether to print progress information - - Returns: - tuple: Contains: - - orig_fsaverage_vox2vox: Transformation matrix from original to fsaverage voxel space - - orig_fsaverage_ras2ras: Transformation matrix from original to fsaverage RAS space - - fsaverage_hires_affine: High-resolution fsaverage affine matrix - - fsaverage_header: FSAverage header fields for LTA writing - + Parameters + ---------- + aseg_nib : nibabel.Nifti1Image + Subject's segmentation image. + verbose : bool, optional + Whether to print progress information, by default False. + + Returns + ------- + orig_fsaverage_vox2vox : np.ndarray + Transformation matrix from original to fsaverage voxel space. + orig_fsaverage_ras2ras : np.ndarray + Transformation matrix from original to fsaverage RAS space. + fsaverage_hires_affine : np.ndarray + High-resolution fsaverage affine matrix. + fsaverage_header : nibabel.Nifti1Header + FSAverage header fields for LTA writing. + vox2ras_tkr : np.ndarray + Voxel to RAS tkr-space transformation matrix. + + Notes + ----- + The function uses pre-computed fsaverage centroids and data from static files + to perform the registration. It matches corresponding anatomical structures + between the subject's segmentation and fsaverage space. """ if verbose: print("Centroid registration") @@ -306,25 +322,37 @@ def centroid_registration(aseg_nib, verbose=False): return orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, fsaverage_hires_affine, fsaverage_header, vox2ras_tkr -def localize_ac_pc(midslices, aseg_nib, orig_fsaverage_vox2vox, model_localization, slices_to_analyze): +def localize_ac_pc( + midslices: np.ndarray, + aseg_nib: "nib.Nifti1Image", + orig_fsaverage_vox2vox: np.ndarray, + model_localization: "torch.nn.Module", + slices_to_analyze: int +) -> tuple[np.ndarray, np.ndarray]: """Localize anterior and posterior commissure points in the brain. Uses a trained model to detect AC and PC points in mid-sagittal slices, using the third ventricle as an anatomical reference. 
- Args: - midslices (np.ndarray): Array of mid-sagittal slices - aseg_nib (nib.Nifti1Image): Subject's segmentation image - orig_fsaverage_vox2vox (np.ndarray): Transformation matrix to fsaverage space - fsaverage_hires_affine (np.ndarray): High-resolution fsaverage affine matrix - model_localization: Trained model for AC-PC detection - slices_to_analyze (int): Number of slices to process - - Returns: - tuple: Contains: - - ac_coords (np.ndarray): Coordinates of the anterior commissure - - pc_coords (np.ndarray): Coordinates of the posterior commissure - + Parameters + ---------- + midslices : np.ndarray + Array of mid-sagittal slices. + aseg_nib : nibabel.Nifti1Image + Subject's segmentation image. + orig_fsaverage_vox2vox : np.ndarray + Transformation matrix to fsaverage space. + model_localization : torch.nn.Module + Trained model for AC-PC detection. + slices_to_analyze : int + Number of slices to process. + + Returns + ------- + ac_coords : np.ndarray + Coordinates of the anterior commissure. + pc_coords : np.ndarray + Coordinates of the posterior commissure. """ # get center of third ventricle from aseg and map to fsaverage space @@ -344,28 +372,40 @@ def localize_ac_pc(midslices, aseg_nib, orig_fsaverage_vox2vox, model_localizati return ac_coords, pc_coords -def segment_cc(midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, slices_to_analyze): +def segment_cc( + midslices: np.ndarray, + ac_coords: np.ndarray, + pc_coords: np.ndarray, + aseg_nib: "nib.Nifti1Image", + model_segmentation: "torch.nn.Module", + slices_to_analyze: int +) -> tuple[np.ndarray, np.ndarray]: """Segment the corpus callosum using a trained model. Performs corpus callosum segmentation on mid-sagittal slices using a trained model, - with AC-PC points as anatomical references. Includes post-processing to clean the segmentation. - - Args: - midslices (np.ndarray): Array of mid-sagittal slices - ac_coords (np.ndarray): Anterior commissure coordinates - pc_coords (np.ndarray): Posterior commissure coordinates - aseg_nib (nib.Nifti1Image): Subject's segmentation image - orig_fsaverage_vox2vox (np.ndarray): Transformation matrix to fsaverage space - fsaverage_hires_affine (np.ndarray): High-resolution fsaverage affine matrix - model_segmentation: Trained model for CC segmentation - slices_to_analyze (int): Number of slices to process - verbose (bool): Whether to print progress information - - Returns: - tuple: Contains: - - segmentation (np.ndarray): Binary segmentation of the corpus callosum - - outputs_soft (np.ndarray): Soft segmentation probabilities - + with AC-PC points as anatomical references. Includes post-processing to clean + the segmentation. + + Parameters + ---------- + midslices : np.ndarray + Array of mid-sagittal slices. + ac_coords : np.ndarray + Anterior commissure coordinates. + pc_coords : np.ndarray + Posterior commissure coordinates. + aseg_nib : nibabel.Nifti1Image + Subject's segmentation image. + model_segmentation : torch.nn.Module + Trained model for CC segmentation. + slices_to_analyze : int + Number of slices to process. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + - segmentation : Binary segmentation of the corpus callosum. + - outputs_soft : Soft segmentation probabilities. 
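+
+    Examples
+    --------
+    Illustrative call, assuming a loaded model and conformed inputs::
+
+        segmentation, outputs_soft = segment_cc(
+            midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, slices_to_analyze
+        )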
""" # get 5 mm of slices output with 9 slices per inference midslices_middle = midslices.shape[0] // 2 @@ -414,7 +454,6 @@ def main( contour_smoothing: float = 5, save_template: str | Path | None = None, cpu: bool = False, - # output paths upright_volume_path: str | Path = None, segmentation_path: str | Path = None, postproc_results_path: str | Path = None, @@ -434,80 +473,85 @@ def main( ) -> None: """Main pipeline function for corpus callosum analysis. - This function performs the following steps: - 1. Initializes environment and loads models - 2. Registers input image to fsaverage space - 3. Detects AC and PC points - 4. Segments the corpus callosum - 5. Performs enhanced post-processing analysis - 6. Saves results and visualizations - + This function performs the complete corpus callosum analysis pipeline including + registration, landmark detection, segmentation, and morphometry analysis. + + Parameters + ---------- + in_mri_path : str or Path + Path to input MRI file. + aseg_path : str or Path + Path to input segmentation file. + output_dir : str or Path + Directory for output files. + slice_selection : str, optional + Which slices to process ('middle', 'all', or specific slice number), by default 'middle'. + debug_output_dir : str or Path, optional + Directory for debug outputs, by default None. + verbose : bool, optional + Flag for verbose output, by default False. + num_thickness_points : int, optional + Number of points for thickness estimation, by default 100. + subdivisions : list[float], optional + List of subdivision fractions for CC subsegmentation, by default None. + subdivision_method : str, optional + Method for contour subdivision ('shape', 'vertical', 'angular', 'eigenvector'), by default 'shape'. + contour_smoothing : float, optional + Gaussian sigma for smoothing during contour detection, by default 5. + save_template : str or Path, optional + Directory path where to save contours.txt and thickness_values.txt files, by default None. + cpu : bool, optional + Force CPU usage even when CUDA is available, by default False. + upright_volume_path : str or Path, optional + Path to save upright volume, by default None. + segmentation_path : str or Path, optional + Path to save segmentation, by default None. + postproc_results_path : str or Path, optional + Path to save post-processing results, by default None. + cc_markers_path : str or Path, optional + Path to save CC markers, by default None. + upright_lta_path : str or Path, optional + Path to save upright LTA transform, by default None. + orient_volume_lta_path : str or Path, optional + Path to save orientation transform, by default None. + surf_file_path : str or Path, optional + Path to save surface file, by default None. + overlay_file_path : str or Path, optional + Path to save overlay file, by default None. + cc_html_path : str or Path, optional + Path to save HTML visualization, by default None. + vtk_file_path : str or Path, optional + Path to save VTK file, by default None. + orig_space_segmentation_path : str or Path, optional + Path to save segmentation in original space, by default None. + debug_image_path : str or Path, optional + Path to save debug images, by default None. + thickness_image_path : str or Path, optional + Path to save thickness visualization, by default None. + softlabels_cc_path : str or Path, optional + Path to save CC soft labels, by default None. + softlabels_fn_path : str or Path, optional + Path to save fornix soft labels, by default None. 
+ softlabels_background_path : str or Path, optional + Path to save background soft labels, by default None. + + Notes + ----- The function saves multiple outputs to specified paths or default locations in output_dir: - - cc_markers.json: Contains detected landmarks and measurements - - midplane_slices.mgz: Extracted midplane slices - - upright_volume.mgz: Volume aligned to standard orientation - - segmentation.mgz: Corpus callosum segmentation - - cc_postproc_results.json: Enhanced postprocessing results - - Various visualization plots and transformation matrices - - Args: - in_mri_path: - Path to input MRI file - aseg_path: - Path to input segmentation file - output_dir: - Directory for output files - slice_selection: - Which slices to process ('middle', 'all', or specific slice number) - debug_output_dir: - Optional directory for debug outputs - verbose: - Flag for verbose output - num_thickness_points: - Number of points for thickness estimation - subdivisions: - List of subdivision fractions for CC subsegmentation - subdivision_method: - Method for contour subdivision - contour_smoothing: - Gaussian sigma for smoothing during contour detection - save_template: - Directory path where to save contours.txt and thickness_values.txt files - cpu: - Force CPU usage even when CUDA is available - upright_volume_path: - Path for upright volume output (default: output_dir/upright_volume.mgz) - segmentation_path: - Path for segmentation output (default: output_dir/segmentation.mgz) - postproc_results_path: - Path for postprocessing results (default: output_dir/cc_postproc_results.json) - cc_markers_path: - Path for CC markers output (default: output_dir/cc_markers.json) - upright_lta_path: - Path for upright LTA transform (default: output_dir/upright.lta) - orient_volume_lta_path: - Path for orientation volume LTA transform (default: output_dir/orient_volume.lta) - surf_file_path: - Path for surf file (default: output_dir/surf/callosum.surf) - overlay_file_path: - Path for overlay file (default: output_dir/mri/callosum_seg_aseg_space.mgz) - cc_html_path: - Path for CC HTML file (default: output_dir/qc_snapshots/corpus_callosum.html) - vtk_file_path: - Path for vtk file (default: output_dir/surf/callosum_mesh.vtk) - orig_space_segmentation_path: - Path for segmentation in original space (default: output_dir/mri/segmentation_orig_space.mgz) - debug_image_path: - Path for debug visualization image (default: output_dir/stats/cc_postprocessing.png) - thickness_image_path: - Path for thickness image (default: output_dir/qc_snapshots/corpus_callosum_thickness_3d.png) - softlabels_cc_path: - Path for cc softlabels (default: output_dir/mri/callosum_seg_soft.mgz) - softlabels_fn_path: - Path for fornix softlabels (default: output_dir/mri/fornix_seg_soft.mgz) - softlabels_background_path: - Path for background softlabels (default: output_dir/mri/background_seg_soft.mgz) - + - cc_markers.json: Contains detected landmarks and measurements. + - midplane_slices.mgz: Extracted midplane slices. + - upright_volume.mgz: Volume aligned to standard orientation. + - segmentation.mgz: Corpus callosum segmentation. + - cc_postproc_results.json: Enhanced postprocessing results. + - Various visualization plots and transformation matrices. + + The pipeline consists of the following steps: + 1. Initializes environment and loads models. + 2. Registers input image to fsaverage space. + 3. Detects AC and PC points. + 4. Segments the corpus callosum. + 5. Performs enhanced post-processing analysis. + 6. 
Saves results and visualizations. """ if subdivisions is None: @@ -586,7 +630,7 @@ def main( # start saving upright volume IO_processes.append( run_in_background( - apply_transform_and_map_volume, + apply_transform_to_volume, False, orig.get_fdata(), orig_fsaverage_vox2vox, @@ -670,7 +714,9 @@ def main( middle_slice_result = slice_results[len(slice_results) // 2] if len(middle_slice_result['split_contours']) <= 5: - subdivision_mask = make_subdivision_mask(segmentation.shape[1:], middle_slice_result['split_contours']) + subdivision_mask = make_subdivision_mask(segmentation.shape[1:], + middle_slice_result['split_contours'], + orig.header.get_zooms()) else: logger.warning("Too many subsegments for lookup table, skipping sub-divion of output segmentation.") subdivision_mask = None @@ -754,7 +800,6 @@ def main( voxel_size=orig.header.get_zooms() ) cc_volume_contour = segmentation_postprocessing.get_cc_volume_contour( - desired_width_mm=5, cc_contours=outer_contours, voxel_size=orig.header.get_zooms() ) @@ -803,7 +848,7 @@ def main( ac_coords_3d = np.hstack((FSAVERAGE_MIDDLE, ac_coords)) pc_coords_3d = np.hstack((FSAVERAGE_MIDDLE, pc_coords)) standardized_to_orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig = ( - get_mapping_to_standard_space(orig, ac_coords_3d, pc_coords_3d, orig_fsaverage_vox2vox, output_dir) + get_mapping_to_standard_space(orig, ac_coords_3d, pc_coords_3d, orig_fsaverage_vox2vox) ) diff --git a/CorpusCallosum/localization/localization_inference.py b/CorpusCallosum/localization/localization_inference.py index e90d1647..6527acdc 100644 --- a/CorpusCallosum/localization/localization_inference.py +++ b/CorpusCallosum/localization/localization_inference.py @@ -25,16 +25,23 @@ from FastSurferCNN.download_checkpoints import main as download_checkpoints -def load_model(checkpoint_path, device=None): - """ - Load the trained numerical localization model from checkpoint - - Args: - checkpoint_path: Path to model checkpoint - device: torch device to load model to (defaults to CUDA if available) - - Returns: - model: Loaded model +def load_model(checkpoint_path: str | Path | None = None, + device: torch.device | None = None) -> DenseNet: + """Load trained numerical localization model from checkpoint. + + Parameters + ---------- + checkpoint_path : str or Path or None, optional + Path to model checkpoint, by default None. + If None, downloads and uses default checkpoint. + device : torch.device or None, optional + Device to load model to, by default None. + If None, uses CUDA if available, else CPU. + + Returns + ------- + DenseNet + Loaded and initialized model in evaluation mode """ if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -60,7 +67,6 @@ def load_model(checkpoint_path, device=None): ) checkpoint_path = cc_config['localization'] - # Load state dict if isinstance(checkpoint_path, str) or isinstance(checkpoint_path, Path): state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True) @@ -68,23 +74,24 @@ def load_model(checkpoint_path, device=None): state_dict = state_dict['model_state_dict'] else: state_dict = checkpoint_path - model.load_state_dict(state_dict) model = model.to(device) model.eval() return model -def get_transforms(): - """Get preprocessing transforms for inference""" + +def get_transforms() -> transforms.Compose: + """Get preprocessing transforms for inference. 
+ + Returns + ------- + transforms.Compose + Composed transform pipeline including: + - Intensity scaling to [0,1] + - Fixed size cropping around AC-PC points + """ tr = [ - # transforms.LoadImaged( - # keys=['image'], - # reader="NibabelReader", - # image_only=True, - # dtype=torch.float32, - # ensure_channel_first=True - # ), transforms.ScaleIntensityd(keys=['image'], minv=0, maxv=1), CropAroundACPCFixedSize( keys=['image'], @@ -94,16 +101,28 @@ def get_transforms(): ] return transforms.Compose(tr) -def preprocess_volume(image_volume, center_pt, transform=None): - """ - Preprocess a volume for inference - - Args: - image_volume: Input volume as numpy array or path to nifti file - transform: Optional custom transform pipeline - - Returns: - preprocessed: Preprocessed image tensor ready for model input + +def preprocess_volume( + image_volume: np.ndarray, + center_pt: np.ndarray, + transform: transforms.Transform | None = None +) -> dict[str, torch.Tensor]: + """Preprocess a volume for inference. + + Parameters + ---------- + image_volume : np.ndarray + Input image volume + center_pt : np.ndarray + Center point coordinates for cropping + transform : transforms.Transform or None, optional + Custom transform pipeline, by default None. + If None, uses default transforms from get_transforms(). + + Returns + ------- + dict[str, torch.Tensor] + Dictionary containing preprocessed image tensor """ if transform is None: transform = get_transforms() @@ -120,18 +139,36 @@ def preprocess_volume(image_volume, center_pt, transform=None): return transformed -def run_inference(model, image_volume, third_ventricle_center, device=None, transform=None): +def run_inference(model: torch.nn.Module, + image_volume: np.ndarray, + third_ventricle_center: np.ndarray, + device: torch.device | None = None, + transform: transforms.Transform | None = None + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, tuple[int, int]]: """ Run inference on an image volume - Args: - model: Trained model - image_volume: Input volume as numpy array or path to nifti file - device: torch device to run inference on - transform: Optional custom transform pipeline - - Returns: - dict containing predicted AC and PC coordinates in original image space + Parameters + ---------- + model : torch.nn.Module + Trained model for inference + image_volume : np.ndarray + Input volume as numpy array + third_ventricle_center : np.ndarray + Initial center point estimate for cropping + device : torch.device, optional + Device to run inference on, by default None + transform : transforms.Transform, optional + Custom transform pipeline, by default None + + Returns + ------- + tuple + Tuple containing: + - np.ndarray: Predicted PC coordinates + - np.ndarray: Predicted AC coordinates + - np.ndarray: Processed input images + - tuple: Crop offsets (left, top) """ if device is None: device = next(model.parameters()).device @@ -183,12 +220,34 @@ def run_inference(model, image_volume, third_ventricle_center, device=None, tran (t_dict['crop_left'], t_dict['crop_top'])) -def run_inference_on_slice(model, image_slice, center_pt, debug_output=None): +def run_inference_on_slice(model: torch.nn.Module, + image_slice: np.ndarray, + center_pt: np.ndarray, + debug_output: str | None = None) -> tuple[np.ndarray, np.ndarray]: + """Run inference on a single slice to detect AC and PC points. 
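+
+    Inference is run twice: the crop is first centered on ``center_pt``, then
+    re-centered on the midpoint of the first-pass AC and PC estimates for a
+    second, refined pass.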
+ + Parameters + ---------- + model : torch.nn.Module + Trained model for AC-PC detection + image_slice : np.ndarray + 3D image slice to run inference on + center_pt : np.ndarray + Initial center point estimate for cropping + debug_output : str, optional + Path to save debug visualization, by default None + + Returns + ------- + tuple[np.ndarray, np.ndarray] + Detected AC and PC coordinates as (ac_coords, pc_coords) + Each coordinate array has shape (2,) containing [y,x] positions + """ # Run inference - pc_coords, ac_coords, inputs, (crop_left, crop_top) = run_inference(model, image_slice, center_pt) + pc_coords, ac_coords, _, (crop_left, crop_top) = run_inference(model, image_slice, center_pt) center_pt = np.mean(np.concatenate([ac_coords, pc_coords], axis=0), axis=0) - pc_coords, ac_coords, inputs, (crop_left, crop_top) = run_inference(model, image_slice, center_pt) + pc_coords, ac_coords, _, (crop_left, crop_top) = run_inference(model, image_slice, center_pt) pc_coords = np.mean(pc_coords, axis=0) ac_coords = np.mean(ac_coords, axis=0) diff --git a/CorpusCallosum/paint_cc_into_pred.py b/CorpusCallosum/paint_cc_into_pred.py index ec649a86..97cdf74e 100644 --- a/CorpusCallosum/paint_cc_into_pred.py +++ b/CorpusCallosum/paint_cc_into_pred.py @@ -44,13 +44,7 @@ def argument_parse(): - """ - Create a command line interface and return command line options. - - Returns - ------- - options : argparse.Namespace - Namespace object holding options. + """Create a command line interface and return command line options. """ parser = argparse.ArgumentParser(usage=HELPTEXT) parser.add_argument( @@ -80,23 +74,25 @@ def argument_parse(): return args -def paint_in_cc(pred: npt.ArrayLike, aseg_cc: npt.ArrayLike) -> npt.ArrayLike: - """ - Paint corpus callosum segmentation into aseg+dkt segmentation map. - - Note, that this function modifies the original array and does not create a copy. +def paint_in_cc(pred: npt.NDArray[np.int_], aseg_cc: npt.NDArray[np.int_]) -> npt.NDArray[np.int_]: + """Paint corpus callosum segmentation into aseg+dkt segmentation map. Parameters ---------- - asegdkt : npt.ArrayLike + pred : npt.NDArray[np.int_] Deep-learning segmentation map. - aseg_cc : npt.ArrayLike + aseg_cc : npt.NDArray[np.int_] Aseg segmentation with CC. Returns ------- - asegdkt + npt.NDArray[np.int_] Segmentation map with added CC. + + Notes + ----- + This function modifies the original array and does not create a copy. + The CC labels (251-255) from aseg_cc are copied into pred. """ cc_mask = (aseg_cc >= 251) & (aseg_cc <= 255) pred[cc_mask] = aseg_cc[cc_mask] diff --git a/CorpusCallosum/registration/mapping_helpers.py b/CorpusCallosum/registration/mapping_helpers.py index a3861ccd..f53381d2 100644 --- a/CorpusCallosum/registration/mapping_helpers.py +++ b/CorpusCallosum/registration/mapping_helpers.py @@ -1,5 +1,8 @@ +from pathlib import Path + import nibabel as nib import numpy as np +import SimpleITK as sitk from scipy.ndimage import affine_transform import FastSurferCNN.utils.logging as logging @@ -7,17 +10,23 @@ logger = logging.get_logger(__name__) -def make_midplane_affine(orig_affine, slices_to_analyze=1, offset=4): - """ - Creates an affine transformation matrix for midplane slices. +def make_midplane_affine(orig_affine: np.ndarray, slices_to_analyze: int = 1, + offset: int = 4) -> np.ndarray: + """Create affine transformation matrix for midplane slices. 
- Args: - orig_affine: Original image affine matrix - slices_to_analyze: Number of slices to analyze around midplane (default=1) - offset: Additional offset in x direction (default=4) + Parameters + ---------- + orig_affine : np.ndarray + Original image affine matrix (4x4) + slices_to_analyze : int, optional + Number of slices to analyze around midplane, by default 1 + offset : int, optional + Additional offset in x direction, by default 4 - Returns: - seg_affine: Affine matrix for midplane slices + Returns + ------- + np.ndarray + 4x4 affine matrix for midplane slices """ # Create translation matrix to center on midplane orig_to_seg = np.eye(4) @@ -29,16 +38,23 @@ def make_midplane_affine(orig_affine, slices_to_analyze=1, offset=4): return seg_affine -def correct_nodding(ac_pt, pc_pt): - """ - Calculates rotation matrix to correct for head nodding based on AC-PC line orientation. +def correct_nodding(ac_pt: np.ndarray, pc_pt: np.ndarray) -> np.ndarray: + """Calculate rotation matrix to correct head nodding. - Args: - ac_pt: Coordinates of the anterior commissure point - pc_pt: Coordinates of the posterior commissure point + Calculates rotation matrix to align AC-PC line with posterior direction, + correcting for head nodding based on AC-PC line orientation. - Returns: - rotation_matrix: 3x3 rotation matrix to align AC-PC line with posterior direction + Parameters + ---------- + ac_pt : np.ndarray + Coordinates of the anterior commissure point + pc_pt : np.ndarray + Coordinates of the posterior commissure point + + Returns + ------- + np.ndarray + 3x3 rotation matrix to align AC-PC line with posterior direction """ ac_pc_vec = pc_pt - ac_pt ac_pc_dist = np.linalg.norm(ac_pc_vec) @@ -74,17 +90,22 @@ def correct_nodding(ac_pt, pc_pt): return rotation_matrix -def apply_transform_to_pt(pts, T, inv=False): - """ - Applies an homoegenous 4x4 transformation matrix to a point. +def apply_transform_to_pt(pts: np.ndarray, T: np.ndarray, inv: bool = False) -> np.ndarray: + """Apply homogeneous transformation matrix to points. - Args: - pts: Point coordinates to transform - T: Transformation matrix - inv: If True, applies inverse of transformation (default=False) + Parameters + ---------- + pts : np.ndarray + Point coordinates to transform, shape (3,) or (3, N) + T : np.ndarray + 4x4 homogeneous transformation matrix + inv : bool, optional + If True, applies inverse of transformation, by default False - Returns: - Transformed point coordinates + Returns + ------- + np.ndarray + Transformed point coordinates, shape (3,) or (3, N) """ if inv: T = T.copy() @@ -97,21 +118,35 @@ def apply_transform_to_pt(pts, T, inv=False): def get_mapping_to_standard_space( - orig, ac_coords_3d, pc_coords_3d, orig_fsaverage_vox2vox, output_dir -): - """ - Maps an image to standard space using AC-PC alignment. - - Args: - orig: Original image - ac_coords_3d: 3D coordinates of anterior commissure - pc_coords_3d: 3D coordinates of posterior commissure - orig_fsaverage_vox2vox: Original to fsaverage space transformation matrix - output_dir: Directory for output files - - Returns: - tuple: (transformation matrix, AC coords standardized, PC coords standardized, - AC coords original, PC coords original) + orig: "nib.Nifti1Image", + ac_coords_3d: np.ndarray, + pc_coords_3d: np.ndarray, + orig_fsaverage_vox2vox: np.ndarray, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Get transformations to map image to standard space. 
+
+    Parameters
+    ----------
+    orig : nib.Nifti1Image
+        Original image
+    ac_coords_3d : np.ndarray
+        AC coordinates in 3D space
+    pc_coords_3d : np.ndarray
+        PC coordinates in 3D space
+    orig_fsaverage_vox2vox : np.ndarray
+        Transformation matrix from original to fsaverage space
+
+    Returns
+    -------
+    tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]
+        Contains:
+        - standardized_to_orig_vox2vox : Transformation matrix from standardized to original voxel space
+        - ac_coords_standardized : AC coordinates in standard space
+        - pc_coords_standardized : PC coordinates in standard space
+        - ac_coords_orig : AC coordinates in original space
+        - pc_coords_orig : PC coordinates in original space
     """
     image_center = np.array(orig.shape) / 2
@@ -168,21 +203,43 @@
 
-def apply_transform_and_map_volume(
-    volume, transform, affine, header, output_path=None, order=3, output_size=None
-):
-    """
-    Applies transformation to a volume and saves the result.
-
-    Args:
-        volume: Input volume data
-        transform: Transformation matrix to apply
-        affine: Affine matrix for the output image
-        header: Header for the output image
-        output_path: Path to save transformed volume
-
-    Returns:
-        transformed: Transformed volume data
+def apply_transform_to_volume(
+    volume: np.ndarray,
+    transform: np.ndarray,
+    affine: np.ndarray,
+    header: nib.freesurfer.mghformat.MGHHeader,
+    output_path: str | Path | None = None,
+    output_size: np.ndarray | None = None,
+    order: int = 1
+) -> np.ndarray:
+    """Apply transformation to a volume and save the result.
+
+    Parameters
+    ----------
+    volume : np.ndarray
+        Input volume data
+    transform : np.ndarray
+        Transformation matrix to apply
+    affine : np.ndarray
+        Affine matrix for the output image
+    header : nib.freesurfer.mghformat.MGHHeader
+        Header for the output image
+    output_path : str or Path or None, optional
+        Path to save transformed volume, by default None
+    output_size : np.ndarray or None, optional
+        Size of output volume, by default None (uses input size)
+    order : int, optional
+        Order of interpolation, by default 1
+
+    Returns
+    -------
+    np.ndarray
+        Transformed volume data
+
+    Notes
+    -----
+    Uses scipy.ndimage.affine_transform for the transformation.
+    If output_path is provided, saves the result as a MGH file.
     """
 
     if output_size is None:
@@ -199,15 +256,25 @@
 
-def make_affine(simpleITKImage):
-    """
-    Creates an affine transformation matrix from a SimpleITK image.
+def make_affine(simpleITKImage: sitk.Image) -> np.ndarray:
+    """Create an affine transformation matrix from a SimpleITK image.
+
+    Parameters
+    ----------
+    simpleITKImage : sitk.Image
+        Input SimpleITK image
 
-    Args:
-        simpleITKImage: Input SimpleITK image
+    Returns
+    -------
+    np.ndarray
+        4x4 affine transformation matrix in RAS coordinates
 
-    Returns:
-        affine: 4x4 affine transformation matrix in RAS coordinates
+    Notes
+    -----
+    The function:
+    1. Gets affine transform in LPS coordinates
+    2. Converts to RAS coordinates to match nibabel
+    3. Returns the final 4x4 transformation matrix
     """
     # get affine transform in LPS
     c = [
@@ -226,27 +293,47 @@
 
 def map_softlabels_to_orig(
-    outputs_soft,
-    orig_fsaverage_vox2vox,
-    orig,
-    slices_to_analyze,
-    orig_space_segmentation_path=None,
-    fsaverage_middle=128,
-    subdivision_mask=None,
-):
-    """
-    Maps soft labels back to original image space and applies post-processing.
- - # TODO: this could by padding after the transform - - Args: - outputs_soft: Soft label predictions - orig_fsaverage_vox2vox: Original to fsaverage space transformation - orig: Original image - slices_to_analyze: Number of slices to analyze - - Returns: - segmentation_orig_space: Final segmentation in original image space + outputs_soft: np.ndarray, + orig_fsaverage_vox2vox: np.ndarray, + orig: np.ndarray, + slices_to_analyze: int, + orig_space_segmentation_path: str | Path | None = None, + fsaverage_middle: int = 128, + subdivision_mask: np.ndarray | None = None +) -> np.ndarray: + """Map soft labels back to original image space and apply post-processing. + + Parameters + ---------- + outputs_soft : np.ndarray + Soft label predictions + orig_fsaverage_vox2vox : np.ndarray + Original to fsaverage space transformation + orig : np.ndarray + Original image + slices_to_analyze : int + Number of slices to analyze + orig_space_segmentation_path : str or Path or None, optional + Path to save segmentation in original space, by default None + fsaverage_middle : int, optional + Middle slice index in fsaverage space, by default 128 + subdivision_mask : np.ndarray or None, optional + Mask for subdividing regions, by default None + + Returns + ------- + np.ndarray + Final segmentation in original image space + + Notes + ----- + The function: + 1. Pads soft labels to original image size + 2. Transforms each label channel separately + 3. Applies post-processing if needed + 4. Optionally saves result to file + + TODO: This could be optimized by padding after the transform """ # map softlabels to original image @@ -315,17 +402,25 @@ def map_softlabels_to_orig( return segmentation_orig_space -def interpolate_midplane(orig, orig_fsaverage_vox2vox, slices_to_analyze): - """ - Interpolates image data at the midplane using a grid of points. - - Args: - orig: Original image - orig_fsaverage_vox2vox: Original to fsaverage space transformation - slices_to_analyze: Number of slices to analyze - - Returns: - transformed: Interpolated image data at midplane +def interpolate_midplane( + orig: nib.Nifti1Image, + orig_fsaverage_vox2vox: np.ndarray, + slices_to_analyze: int) -> np.ndarray: + """Interpolates image data at the midplane using a grid of points. + + Parameters + ---------- + orig : nib.Nifti1Image + Original image + orig_fsaverage_vox2vox : np.ndarray + Original to fsaverage space transformation matrix + slices_to_analyze : int + Number of slices to analyze around midplane + + Returns + ------- + np.ndarray + Interpolated image data at midplane """ # slice_thickness = 9+slices_to_analyze-1 diff --git a/CorpusCallosum/segmentation/segmentation_inference.py b/CorpusCallosum/segmentation/segmentation_inference.py index 0e95ec4d..a013fdc7 100644 --- a/CorpusCallosum/segmentation/segmentation_inference.py +++ b/CorpusCallosum/segmentation/segmentation_inference.py @@ -24,20 +24,25 @@ from FastSurferCNN.models.networks import FastSurferVINN -def load_model(checkpoint_path, device=None): - """ - Load the trained model from checkpoint - - Args: - checkpoint_path: Path to model checkpoint - device: torch device to load model to (defaults to CUDA if available) - - Returns: - model: Loaded model +def load_model(checkpoint_path: str | None = None, device: torch.device | None = None) -> FastSurferVINN: + """Load trained model from checkpoint. + + Parameters + ---------- + checkpoint_path : str or None, optional + Path to model checkpoint, by default None. + If None, downloads and uses default checkpoint. 
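[Review note] `interpolate_midplane` above resamples the original volume on an fsaverage-space grid; the general pattern looks roughly like the following with `scipy.ndimage.map_coordinates` (shapes and the 256-voxel fsaverage extent are assumptions, not the module's exact code):

```python
import numpy as np
from scipy.ndimage import map_coordinates

def sample_fsaverage_slice(orig_data, orig_fsaverage_vox2vox, slice_x, out_shape=(256, 256)):
    """Sketch: interpolate one sagittal fsaverage slice from the original volume."""
    ys, zs = np.mgrid[0:out_shape[0], 0:out_shape[1]]
    grid = np.stack([
        np.full(ys.size, float(slice_x)),  # fixed left-right position
        ys.ravel().astype(float),
        zs.ravel().astype(float),
        np.ones(ys.size),
    ])
    src = np.linalg.inv(orig_fsaverage_vox2vox) @ grid  # fsaverage -> original voxels
    return map_coordinates(orig_data, src[:3], order=1).reshape(out_shape)
```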
+ device : torch.device or None, optional + Device to load model to, by default None. + If None, uses CUDA if available, else CPU. + + Returns + ------- + FastSurferVINN + Loaded and initialized model in evaluation mode """ if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - params = { "num_classes": 3, @@ -67,27 +72,48 @@ def load_model(checkpoint_path, device=None): ) checkpoint_path = cc_config['segmentation'] - #model = torch.load(checkpoint_path, map_location=device, weights_only=False) weights = torch.load(checkpoint_path, weights_only=True, map_location=device) model.load_state_dict(weights) model.eval() model.to(device) return model -def run_inference(model, image_slice, AC_center, PC_center, voxel_size, device=None, transform=None): - """ - Run inference on a single image slice - - Args: - model: Trained model - image_slice: Input image as numpy array - device: torch device to run inference on - transform: Optional custom transform pipeline - - Returns: - dict containing: - segmentation: Segmentation map if model produces segmentation - landmarks: Predicted landmarks if model produces localization + +def run_inference( + model: FastSurferVINN, + image_slice: np.ndarray, + AC_center: np.ndarray, + PC_center: np.ndarray, + voxel_size: float, + device: torch.device | None = None, + transform: transforms.Transform | None = None +) -> dict[str, np.ndarray]: + """Run inference on a single image slice. + + Parameters + ---------- + model : FastSurferVINN + Trained model + image_slice : np.ndarray + Input image as numpy array + AC_center : np.ndarray + Anterior commissure coordinates + PC_center : np.ndarray + Posterior commissure coordinates + voxel_size : float + Voxel size in mm + device : torch.device or None, optional + Device to run inference on, by default None. + If None, uses the device of the model. + transform : transforms.Transform or None, optional + Custom transform pipeline, by default None + + Returns + ------- + dict[str, np.ndarray] + Dictionary containing: + - segmentation : Binary segmentation map + - landmarks : Predicted landmark coordinates """ orig_shape = image_slice.shape @@ -95,7 +121,28 @@ def run_inference(model, image_slice, AC_center, PC_center, voxel_size, device=N #device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = next(model.parameters()).device - def crop_around_acpc(img, ac, pc, vox_size): + def crop_around_acpc(img: np.ndarray, + ac: np.ndarray, + pc: np.ndarray, + vox_size: float) -> dict[str, np.ndarray]: + """Crop image around AC-PC points. 
+
+        Parameters
+        ----------
+        img : np.ndarray
+            Input image
+        ac : np.ndarray
+            Anterior commissure coordinates
+        pc : np.ndarray
+            Posterior commissure coordinates
+        vox_size : float
+            Voxel size in mm
+
+        Returns
+        -------
+        dict[str, np.ndarray]
+            Dictionary containing cropped image and metadata
+        """
        return CropAroundACPC(keys=['image'], padding_mm=35, random_translate=0)(
            {'image': img, 'AC_center': ac, 'PC_center': pc, 'res': vox_size}
        )
@@ -166,8 +213,6 @@ def crop_around_acpc(img, ac, pc, vox_size):
        outputs_soft.transpose(0,2,3,1),
    )
 
-# TODO: load validation data and run inference on it to confirm correct processing
-
 
 def load_validation_data(path):
     import pandas as pd
@@ -212,12 +257,39 @@ def one_hot_to_label(one_hot, label_ids=None):
     return label
 
 
-# TODO: add heuristic that removes islands that are far away
-
-
-
-def run_inference_on_slice(model, test_slice, AC_center, PC_center, voxel_size):
+def run_inference_on_slice(model: FastSurferVINN,
+                           test_slice: np.ndarray,
+                           AC_center: np.ndarray,
+                           PC_center: np.ndarray,
+                           voxel_size: float) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+    """Run inference on a single slice.
+
+    Parameters
+    ----------
+    model : FastSurferVINN
+        Trained model for inference
+    test_slice : np.ndarray
+        Input image slice
+    AC_center : np.ndarray
+        Anterior commissure coordinates
+    PC_center : np.ndarray
+        Posterior commissure coordinates
+    voxel_size : float
+        Voxel size in mm
+
+    Returns
+    -------
+    results : np.ndarray
+        Label map after one-hot conversion
+    inputs : np.ndarray
+        Preprocessed input image
+    outputs_avg : np.ndarray
+        Averaged model outputs
+    outputs_soft : np.ndarray
+        Softlabel outputs (non-discrete)
+
+    """
    # add zero in front of AC_center and PC_center
     AC_center = np.concatenate([np.zeros(1), AC_center])
     PC_center = np.concatenate([np.zeros(1), PC_center])
@@ -226,54 +298,3 @@ def run_inference_on_slice(model, test_slice, AC_center, PC_center, voxel_size):
     results = one_hot_to_label(results)
 
     return results, inputs, outputs_avg, outputs_soft
-
-
-
-def remove_small_clusters(label_data, min_cluster_size=100):
-    """
-    Removes small clusters of connected components from a label image.
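[Review note] `one_hot_to_label` above still has no docstring; the core of the conversion is a channel argmax mapped back to FreeSurfer label ids. A sketch, assuming background/CC/fornix channel order with the label ids 192 and 250 used elsewhere in this patch:

```python
import numpy as np

def one_hot_to_label_sketch(one_hot: np.ndarray, label_ids=(0, 192, 250)) -> np.ndarray:
    """Sketch: map (N, C, H, W) soft predictions to a label map."""
    channel_idx = np.argmax(one_hot, axis=1)   # winning channel per pixel
    return np.asarray(label_ids)[channel_idx]  # channel index -> label id
```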
- - Args: - label_data: numpy array containing the label data - min_cluster_size: minimum size of clusters to keep (default: 100) - - Returns: - cleaned_label: numpy array with small clusters removed - """ - from scipy.ndimage import label as ndlabel - - - list_of_cleaned_labels = [] - - for label_id in range(label_data.shape[1]-1): - - # Create a binary mask of the label - binary_mask = label_data[:,label_id+1] > 0 - - - # Label the connected components - labeled_array, num_features = ndlabel(binary_mask) - - # Create a mask for small clusters - small_clusters_mask = np.zeros_like(binary_mask, dtype=bool) - for i in range(1, num_features + 1): - small_cluster = (labeled_array == i) - if np.sum(small_cluster) < min_cluster_size: - small_clusters_mask |= small_cluster - - # Remove small clusters from the original label - cleaned_label = label_data[:,label_id+1].copy() - cleaned_label[small_clusters_mask] = 0 - list_of_cleaned_labels.append(cleaned_label) - - - # plot binary mask - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots(2,len(binary_mask)) - # for i in range(len(binary_mask)): - # ax[0,i].imshow(binary_mask[i]) - # ax[1,i].imshow(cleaned_label[i]) - # plt.show() - - return np.stack([label_data[:,0]]+list_of_cleaned_labels, axis=1) - diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py index 3fa49738..542e6f4a 100644 --- a/CorpusCallosum/segmentation/segmentation_postprocessing.py +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -25,13 +25,24 @@ def find_component_boundaries(labels_arr: np.ndarray, component_id: int) -> np.ndarray: """Find boundary voxels of a connected component. - - Args: - labels_arr (np.ndarray): Labeled array from connected components analysis - component_id (int): ID of the component to find boundaries for - - Returns: - np.ndarray: Array of boundary coordinates (N, 3) + + Parameters + ---------- + labels_arr : np.ndarray + Labeled array from connected components analysis + component_id : int + ID of the component to find boundaries for + + Returns + ------- + np.ndarray + Array of shape (N, 3) containing boundary coordinates + + Notes + ----- + Uses 6-connectivity (face neighbors only) to determine boundaries. + Boundary voxels are those that are part of the component but have + at least one non-component neighbor. """ component_mask = labels_arr == component_id @@ -47,17 +58,34 @@ def find_component_boundaries(labels_arr: np.ndarray, component_id: int) -> np.n return np.array(np.where(boundary)).T -def find_minimal_connection_path(boundary1: np.ndarray, boundary2: np.ndarray, - max_distance: float = 3.0) -> tuple[np.ndarray, np.ndarray] | None: +def find_minimal_connection_path( + boundary1: np.ndarray, + boundary2: np.ndarray, + max_distance: float = 3.0 +) -> tuple[np.ndarray, np.ndarray] | None: """Find the minimal connection path between two component boundaries. 
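[Review note] The 6-connectivity boundary definition in `find_component_boundaries`' Notes can be stated in a single erosion; an equivalent formulation as a sketch (not necessarily how the function computes it):

```python
import numpy as np
from scipy import ndimage

def boundary_voxels(labels_arr: np.ndarray, component_id: int) -> np.ndarray:
    """Sketch: component voxels with at least one non-component face neighbor."""
    mask = labels_arr == component_id
    face_struct = ndimage.generate_binary_structure(3, 1)  # 6-connectivity
    interior = ndimage.binary_erosion(mask, structure=face_struct)
    return np.argwhere(mask & ~interior)                   # (N, 3) coordinates
```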
- - Args: - boundary1 (np.ndarray): Boundary coordinates of first component (N1, 3) - boundary2 (np.ndarray): Boundary coordinates of second component (N2, 3) - max_distance (float): Maximum distance to consider for connection - - Returns: - tuple | None: (point1, point2) coordinates of closest points if within max_distance, None otherwise + + Parameters + ---------- + boundary1 : np.ndarray + Boundary coordinates of first component, shape (N1, 3) + boundary2 : np.ndarray + Boundary coordinates of second component, shape (N2, 3) + max_distance : float, optional + Maximum distance to consider for connection, by default 3.0 + + Returns + ------- + tuple[np.ndarray, np.ndarray] or None + If a valid connection is found: + - point1 : Coordinates on first boundary + - point2 : Coordinates on second boundary + None if no connection within max_distance is found + + Notes + ----- + Uses Euclidean distance to find the closest pair of points + between the two boundaries. """ if len(boundary1) == 0 or len(boundary2) == 0: return None @@ -78,14 +106,27 @@ def find_minimal_connection_path(boundary1: np.ndarray, boundary2: np.ndarray, def create_connection_line(point1: np.ndarray, point2: np.ndarray) -> list[tuple[int, int, int]]: - """Create a line of voxels connecting two points using simplified 3D line algorithm. - - Args: - point1 (np.ndarray): Starting point coordinates (3,) - point2 (np.ndarray): Ending point coordinates (3,) - - Returns: - list: List of (x, y, z) coordinates forming the connection line + """Create a line of voxels connecting two points. + + Uses a simplified 3D line algorithm to create a sequence of voxels + that form a continuous path between the two points. + + Parameters + ---------- + point1 : np.ndarray + Starting point coordinates, shape (3,) + point2 : np.ndarray + Ending point coordinates, shape (3,) + + Returns + ------- + list[tuple[int, int, int]] + List of (x, y, z) coordinates forming the connection line + + Notes + ----- + The line is created by interpolating between the points using + the maximum distance in any dimension as the number of steps. """ x1, y1, z1 = map(int, point1) x2, y2, z2 = map(int, point2) @@ -119,16 +160,29 @@ def create_connection_line(point1: np.ndarray, point2: np.ndarray) -> list[tuple def connect_nearby_components(seg_arr: np.ndarray, max_connection_distance: float = 3.0) -> np.ndarray: """Connect nearby disconnected components that should be connected. - + This function identifies disconnected components in the segmentation and creates minimal connections between components that are close to each other. - - Args: - seg_arr (np.ndarray): Input binary segmentation array - max_connection_distance (float): Maximum distance to connect components - - Returns: - np.ndarray: Segmentation array with minimal connections added + + Parameters + ---------- + seg_arr : np.ndarray + Input binary segmentation array + max_connection_distance : float, optional + Maximum distance to connect components, by default 3.0 + + Returns + ------- + np.ndarray + Segmentation array with minimal connections added between nearby components + + Notes + ----- + The function: + 1. Identifies connected components in the input segmentation + 2. Finds boundaries between components + 3. Creates minimal connections between nearby components + 4. 
Returns the modified segmentation with added connections """ # Create a copy to modify @@ -224,24 +278,44 @@ def connect_nearby_components(seg_arr: np.ndarray, max_connection_distance: floa return connected_seg -def get_cc_volume_voxel(desired_width_mm: int, cc_mask: np.ndarray, voxel_size: tuple[float, float, float]) -> float: +def get_cc_volume_voxel( + desired_width_mm: int, + cc_mask: np.ndarray, + voxel_size: tuple[float, float, float] +) -> float: """Calculate the volume of the corpus callosum in cubic millimeters. - + This function calculates the volume of the corpus callosum (CC) in cubic millimeters. If the CC width is larger than desired_width_mm, the voxels on the edges are calculated as partial volumes to achieve the desired width. - - Args: - desired_width_mm (int): Desired width of the CC in millimeters - cc_mask (np.ndarray): Binary mask of the corpus callosum - voxel_size (tuple[float, float, float]): Voxel size in millimeters (x, y, z) - - Returns: - float: Volume of the CC in cubic millimeters - - Raises: - ValueError: If CC width is smaller than desired width - AssertionError: If CC mask doesn't have odd number of voxels in x dimension + + Parameters + ---------- + desired_width_mm : int + Desired width of the CC in millimeters + cc_mask : np.ndarray + Binary mask of the corpus callosum + voxel_size : tuple[float, float, float] + Voxel size in millimeters (x, y, z) + + Returns + ------- + float + Volume of the CC in cubic millimeters + + Raises + ------ + ValueError + If CC width is smaller than desired width + AssertionError + If CC mask doesn't have odd number of voxels in x dimension + + Notes + ----- + The function assumes LIA orientation where: + - x dimension corresponds to Left/Right + - y dimension corresponds to Inferior/Superior + - z dimension corresponds to Anterior/Posterior """ assert cc_mask.shape[0] % 2 == 1, "CC mask must have odd number of voxels in x dimension" @@ -276,24 +350,34 @@ def get_cc_volume_voxel(desired_width_mm: int, cc_mask: np.ndarray, voxel_size: else: raise ValueError(f"Width of CC segmentation is smaller than desired width: {width_mm} < {desired_width_mm}") -def get_cc_volume_contour(desired_width_mm: int, cc_contours: list[np.ndarray], - voxel_size: tuple[float, float, float]) -> float: - """Calculate the volume of the corpus callosum in cubic millimeters using Simpson's rule. - - This function calculates the volume of the corpus callosum (CC) in cubic millimeters using Simpson's rule. - If the CC width is larger than desired_width_mm, the voxels on the edges are calculated as - partial volumes to achieve the desired width. - - Args: - desired_width_mm (int): Desired width of the CC in millimeters - cc_contours (list[np.ndarray]): List of CC contours for each slice in the left-right direction - voxel_size (tuple[float, float, float]): Voxel size in millimeters (x, y, z) - - Returns: - float: Volume of the CC in cubic millimeters - - Raises: - ValueError: If CC width is smaller than desired width or insufficient contours for Simpson's rule +def get_cc_volume_contour(cc_contours: list[np.ndarray], + voxel_size: tuple[float, float, float]) -> float: + """Calculate the volume of the corpus callosum using Simpson's rule. 
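[Review note] The interpolation scheme described in `create_connection_line`'s Notes above is compact enough to restate; a sketch of the same idea (an equivalent formulation, name hypothetical):

```python
import numpy as np

def line_voxels(point1, point2):
    """Sketch: voxel line between two points; steps = max per-axis distance."""
    p1 = np.asarray(point1, dtype=int)
    p2 = np.asarray(point2, dtype=int)
    steps = int(np.abs(p2 - p1).max())
    ts = np.linspace(0.0, 1.0, steps + 1)
    return [tuple(np.rint(p1 + t * (p2 - p1)).astype(int)) for t in ts]
```

Because the step count equals the largest per-axis distance, consecutive samples move at most one voxel per axis before rounding, which keeps the path connected.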
+
+    Parameters
+    ----------
+    cc_contours : list[np.ndarray]
+        List of CC contours for each slice in the left-right direction
+    voxel_size : tuple[float, float, float]
+        Voxel size in millimeters (x, y, z)
+
+    Returns
+    -------
+    float
+        Volume of the CC in cubic millimeters
+
+    Raises
+    ------
+    ValueError
+        If fewer than 3 contours are provided for Simpson's rule integration
+
+    Notes
+    -----
+    This function calculates the volume of the corpus callosum (CC) in cubic
+    millimeters by integrating the cross-sectional contour areas along the
+    left-right direction with Simpson's rule.
     """
     if len(cc_contours) < 3:
         raise ValueError("Need at least 3 contours for Simpson's rule integration")
@@ -342,22 +426,30 @@ def get_cc_volume_contour(desired_width_mm: int, cc_contours: list[np.ndarray],
     return integrate.simpson(areas, x=measurement_points)
 
 
-def get_largest_cc(seg_arr: np.ndarray, max_connection_distance: float = 3.0) -> tuple[np.ndarray, np.ndarray]:
-    """Get largest connected component from a binary segmentation array with minimal connections.
-
-    This function takes a binary segmentation array, attempts to connect nearby disconnected
-    components that should be connected, then finds the largest connected component.
-    It first tries to establish minimal connections between close components before
-    falling back to dilation if no connections are made.
-
-    Args:
-        seg_arr (np.ndarray): Input binary segmentation array
-        max_connection_distance (float): Maximum distance to connect components (default: 3.0)
-
-    Returns:
-        tuple: A tuple containing:
-            - clean_seg (np.ndarray): Segmentation array with only the largest connected component
-            - largest_cc (np.ndarray): Binary mask of the largest connected component
+def get_largest_cc(
+    seg_arr: np.ndarray,
+    max_connection_distance: float = 3.0
+) -> np.ndarray:
+    """Get largest connected component from a binary segmentation array.
+
+    Parameters
+    ----------
+    seg_arr : np.ndarray
+        Input binary segmentation array
+    max_connection_distance : float, optional
+        Maximum distance to connect components, by default 3.0
+
+    Returns
+    -------
+    np.ndarray
+        Binary mask of the largest connected component
+
+    Notes
+    -----
+    The function first attempts to connect nearby disconnected components
+    that should be connected, then finds the largest connected component.
+    It uses minimal connections between close components before falling
+    back to dilation if no connections are made.
     """
     # First attempt: try to connect nearby components with minimal connections
     connected_seg = connect_nearby_components(seg_arr, max_connection_distance)
@@ -394,23 +486,33 @@ def get_largest_cc(seg_arr: np.ndarray, max_connection_distance: float = 3.0) ->
     return largest_cc
 
 
-def clean_cc_segmentation(seg_arr: np.ndarray, max_connection_distance: float = 3.0) -> tuple[np.ndarray, np.ndarray]:
+def clean_cc_segmentation(
+    seg_arr: np.ndarray,
+    max_connection_distance: float = 3.0
+) -> tuple[np.ndarray, np.ndarray]:
     """Clean corpus callosum segmentation by removing non-connected components.
-
-    This function processes a segmentation array to clean up the corpus callosum (CC)
-    by removing non-connected components. It first isolates the CC (label 192),
-    attempts to connect nearby disconnected components, then adds the fornix (label 250),
-    and finally removes non-connected components from the combined CC and fornix.
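[Review note] For reviewers checking the integral at the end of `get_cc_volume_contour`: with the per-slice cross-sectional areas already computed, the volume estimate is just Simpson's rule over the left-right positions. A sketch matching the `integrate.simpson` call shown above:

```python
import numpy as np
from scipy import integrate

def cc_volume_simpson(areas_mm2, slice_positions_mm) -> float:
    """Sketch: integrate per-slice contour areas along the left-right axis."""
    if len(areas_mm2) < 3:
        raise ValueError("Need at least 3 contours for Simpson's rule integration")
    return float(integrate.simpson(areas_mm2, x=slice_positions_mm))
```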
- - Args: - seg_arr (np.ndarray): Input segmentation array with CC (192) and fornix (250) labels - max_connection_distance (float): Maximum distance to connect components (default: 3.0) - - Returns: - tuple: A tuple containing: - - clean_seg (np.ndarray): Cleaned segmentation array with only the largest - connected component of CC and fornix - - mask (np.ndarray): Binary mask of the largest connected component + + Parameters + ---------- + seg_arr : np.ndarray + Input segmentation array with CC (192) and fornix (250) labels + max_connection_distance : float, optional + Maximum distance to connect components, by default 3.0 + + Returns + ------- + tuple[np.ndarray, np.ndarray] + - clean_seg : Cleaned segmentation array with only the largest + connected component of CC and fornix + - mask : Binary mask of the largest connected component + + Notes + ----- + The function: + 1. Isolates the CC (label 192) + 2. Attempts to connect nearby disconnected components + 3. Adds the fornix (label 250) + 4. Removes non-connected components from the combined CC and fornix """ # Remove non connected components from the CC alone, with minimal connections cc_seg = np.zeros_like(seg_arr) diff --git a/CorpusCallosum/shape/cc_endpoint_heuristic.py b/CorpusCallosum/shape/cc_endpoint_heuristic.py index 8bf9b8c6..b9d18818 100644 --- a/CorpusCallosum/shape/cc_endpoint_heuristic.py +++ b/CorpusCallosum/shape/cc_endpoint_heuristic.py @@ -13,33 +13,28 @@ # limitations under the License. import lapy -import nibabel import numpy as np -import numpy.typing as npt -import pandas as pd import scipy.ndimage import skimage.measure from scipy.ndimage import label -def smooth_contour(x: npt.NDArray[np.float64], - y: npt.NDArray[np.float64], - window_size: int) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: +def smooth_contour(x: np.ndarray, y: np.ndarray, window_size: int) -> tuple[np.ndarray, np.ndarray]: """Smooth a contour using a moving average filter. Parameters ---------- - x : npt.NDArray[np.float64] - x-coordinates of the contour points - y : npt.NDArray[np.float64] - y-coordinates of the contour points + x : np.ndarray + X-coordinates of the contour points. + y : np.ndarray + Y-coordinates of the contour points. window_size : int Size of the smoothing window. Must be odd and > 2. Returns ------- - tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]] - Smoothed x and y coordinates of the contour + tuple[np.ndarray, np.ndarray] + Smoothed x and y coordinates of the contour. """ # Ensure window_size is an integer window_size = int(window_size) @@ -70,13 +65,13 @@ def smooth_contour(x: npt.NDArray[np.float64], return x_smoothed, y_smoothed -def connect_diagonally_connected_components(cc_mask: npt.NDArray[np.bool_]) -> None: +def connect_diagonally_connected_components(cc_mask: np.ndarray) -> None: """Connect diagonally connected components in the CC mask. Parameters ---------- - cc_mask : npt.NDArray[np.bool_] - Binary mask of the corpus callosum + cc_mask : np.ndarray + Binary mask of the corpus callosum. Notes ----- @@ -134,21 +129,20 @@ def connect_diagonally_connected_components(cc_mask: npt.NDArray[np.bool_]) -> N cc_mask[connects_diagonals] = 1 -def extract_cc_contour(cc_mask: npt.NDArray[np.bool_], - contour_smoothing: int = 5) -> npt.NDArray[np.float64]: +def extract_cc_contour(cc_mask: np.ndarray, contour_smoothing: int = 5) -> np.ndarray: """Extract the contour of the CC from the mask. 
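[Review note] `smooth_contour` above is a moving average; for a closed contour the periodic padding is the subtle part. A sketch under that assumption (a simplification, not the file's exact edge handling, and it assumes the contour is longer than the window):

```python
import numpy as np

def moving_average_closed(values: np.ndarray, window_size: int = 5) -> np.ndarray:
    """Sketch: periodic moving average, so a closed contour has no edge effects."""
    assert window_size % 2 == 1 and window_size > 2
    k = window_size // 2
    padded = np.concatenate([values[-k:], values, values[:k]])  # wrap around
    kernel = np.ones(window_size) / window_size
    return np.convolve(padded, kernel, mode="valid")            # same length as input
```

Applied to `x` and `y` separately, this gives the behavior the docstring describes.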
Parameters ---------- - cc_mask : npt.NDArray[np.bool_] - Binary mask of the corpus callosum + cc_mask : np.ndarray + Binary mask of the corpus callosum. contour_smoothing : int, optional - Window size for contour smoothing, by default 5 + Window size for contour smoothing, by default 5. Returns ------- - npt.NDArray[np.float64] - Array of shape (2, N) containing x,y coordinates of the contour points + np.ndarray + Array of shape (2, N) containing x,y coordinates of the contour points. """ # cc_mask_orig = cc_mask cc_mask = cc_mask.copy() @@ -169,42 +163,42 @@ def extract_cc_contour(cc_mask: npt.NDArray[np.bool_], return contour -def get_endpoints(cc_mask: npt.NDArray[np.bool_], - AC_2d: npt.NDArray[np.float64], - PC_2d: npt.NDArray[np.float64], - resolution: float, - return_coordinates: bool = True, - contour_smoothing: int = 5) -> ( - tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]] | - tuple[npt.NDArray[np.float64], int, int]): +def get_endpoints( + cc_mask: np.ndarray, + AC_2d: np.ndarray, + PC_2d: np.ndarray, + resolution: float, + return_coordinates: bool = True, + contour_smoothing: int = 5 +) -> tuple[np.ndarray, np.ndarray, np.ndarray] | tuple[np.ndarray, int, int]: """Determine endpoints of CC by finding points closest to AC and PC. Parameters ---------- - cc_mask : npt.NDArray[np.bool_] - Binary mask of the corpus callosum - AC_2d : npt.NDArray[np.float64] - 2D coordinates of the anterior commissure - PC_2d : npt.NDArray[np.float64] - 2D coordinates of the posterior commissure + cc_mask : np.ndarray + Binary mask of the corpus callosum. + AC_2d : np.ndarray + 2D coordinates of the anterior commissure. + PC_2d : np.ndarray + 2D coordinates of the posterior commissure. resolution : float - Image resolution in mm + Image resolution in mm. return_coordinates : bool, optional - If True, return endpoint coordinates, otherwise return indices, by default True + If True, return endpoint coordinates, otherwise return indices, by default True. contour_smoothing : int, optional - Window size for contour smoothing, by default 5 + Window size for contour smoothing, by default 5. Returns ------- - tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]] | tuple[npt.NDArray[np.float64], int, int] + tuple[np.ndarray, np.ndarray, np.ndarray] | tuple[np.ndarray, int, int] If return_coordinates is True: - (contour, anterior_point, posterior_point) + (contour, anterior_point, posterior_point). If return_coordinates is False: - (contour, anterior_index, posterior_index) + (contour, anterior_index, posterior_index). Notes ----- - Expects LIA orientation + Expects LIA orientation. """ image_size = cc_mask.shape diff --git a/CorpusCallosum/shape/cc_mesh.py b/CorpusCallosum/shape/cc_mesh.py index 41fb9149..0824a6b2 100644 --- a/CorpusCallosum/shape/cc_mesh.py +++ b/CorpusCallosum/shape/cc_mesh.py @@ -44,23 +44,35 @@ class CC_Mesh(lapy.TriaMesh): corpus callosum, with optional thickness measurements at various points along these contours. - Attributes: - contours (list): List of numpy arrays containing 2D contour points for each slice. - thickness_values (list): List of thickness measurements for each contour point. - start_end_idx (list): List of tuples containing start and end indices for each contour. - ac_coords (numpy.ndarray): Coordinates of the anterior commissure. - pc_coords (numpy.ndarray): Coordinates of the posterior commissure. - resolution (float): Spatial resolution of the mesh. - v (numpy.ndarray): Vertex coordinates of the mesh. 
- t (numpy.ndarray): Triangle indices of the mesh. - original_thickness_vertices (list): List of vertex indices where thickness was originally measured. + Attributes + ---------- + contours : list[np.ndarray] + List of numpy arrays containing 2D contour points for each slice. + thickness_values : list[np.ndarray] + List of thickness measurements for each contour point. + start_end_idx : list[tuple[int, int]] + List of tuples containing start and end indices for each contour. + ac_coords : np.ndarray + Coordinates of the anterior commissure. + pc_coords : np.ndarray + Coordinates of the posterior commissure. + resolution : float + Spatial resolution of the mesh. + v : np.ndarray + Vertex coordinates of the mesh. + t : np.ndarray + Triangle indices of the mesh. + original_thickness_vertices : list[np.ndarray] + List of vertex indices where thickness was originally measured. """ def __init__(self, num_slices): """Initialize a CC_Mesh object. - Args: - num_slices (int): Number of slices in the corpus callosum mesh. + Parameters + ---------- + num_slices : int + Number of slices in the corpus callosum mesh """ self.contours = [None] * num_slices self.thickness_values = [None] * num_slices @@ -81,16 +93,17 @@ def add_contour( ): """Add a contour and its associated thickness values for a specific slice. - Args: - slice_idx (int): - Index of the slice where the contour should be added. - contour (numpy.ndarray): - Array of shape (N, 2) containing 2D contour points. - thickness_values (numpy.ndarray): - Array of thickness measurements for each contour point. - start_end_idx (tuple[int, int], optional): - Tuple containing start and end indices for the contour. - If None, defaults to (0, len(contour)//2). + Parameters + ---------- + slice_idx : int + Index of the slice where the contour should be added. + contour : np.ndarray + Array of shape (N, 2) containing 2D contour points. + thickness_values : np.ndarray + Array of thickness measurements for each contour point. + start_end_idx : tuple[int, int], optional + Tuple containing start and end indices for the contour. + If None, defaults to (0, len(contour)//2). """ self.contours[slice_idx] = contour self.thickness_values[slice_idx] = thickness_values @@ -105,9 +118,12 @@ def add_contour( def set_acpc_coords(self, ac_coords: np.ndarray, pc_coords: np.ndarray): """Set the coordinates of the anterior and posterior commissure. - Args: - ac_coords (numpy.ndarray): 3D coordinates of the anterior commissure. - pc_coords (numpy.ndarray): 3D coordinates of the posterior commissure. + Parameters + ---------- + ac_coords : np.ndarray + 3D coordinates of the anterior commissure. + pc_coords : np.ndarray + 3D coordinates of the posterior commissure. """ self.ac_coords = ac_coords self.pc_coords = pc_coords @@ -115,8 +131,10 @@ def set_acpc_coords(self, ac_coords: np.ndarray, pc_coords: np.ndarray): def set_resolution(self, resolution: float): """Set the spatial resolution of the mesh. - Args: - resolution (float): Spatial resolution in millimeters. + Parameters + ---------- + resolution : float + Spatial resolution in millimeters. """ self.resolution = resolution @@ -135,37 +153,37 @@ def plot_mesh( """Plot the mesh using Plotly for better performance and interactivity. Creates an interactive 3D visualization of the mesh with optional features like - thickness overlay, contour display, and grid visualization. The plot can be saved - to an HTML file or displayed in a web browser. - - Args: - output_path (str, optional): - Path to save the plot. 
If None, displays the plot interactively. - colormap (str, optional): - Which colormap to use. Options are: - - "red_to_blue": Red -> Orange -> Grey -> Light Blue -> Blue - - "red_to_yellow": Red -> Yellow -> Light Blue -> Blue - - "yellow_to_red": Yellow -> Light Blue -> Blue -> Red - - "blue_to_red": Blue -> Light Blue -> Grey -> Orange -> Red - - Defaults to "red_to_yellow". - thickness_overlay (bool, optional): - Whether to overlay thickness values on the mesh. - Defaults to True. - show_contours (bool, optional): - Whether to show the contours. Defaults to False. - show_grid (bool, optional): - Whether to show the grid. Defaults to False. - color_range (tuple[float, float], optional): - Optional tuple of (min, max) to set fixed - color range. Defaults to None. - show_mesh_edges (bool, optional): - Whether to show the mesh edges. Defaults to False. - legend (str, optional): - Legend text for the colorbar. Defaults to "". - threshold (tuple[float, float], optional): - Values between these thresholds will be shown in grey. - Defaults to (-0.2, 0.2). + thickness overlay, contour display, and grid visualization. + + Parameters + ---------- + output_path : str, optional + Path to save the plot. If None, displays the plot interactively. + colormap : str, optional + Which colormap to use, by default "red_to_yellow". + Options: + - "red_to_blue": Red -> Orange -> Grey -> Light Blue -> Blue + - "red_to_yellow": Red -> Yellow -> Light Blue -> Blue + - "yellow_to_red": Yellow -> Light Blue -> Blue -> Red + - "blue_to_red": Blue -> Light Blue -> Grey -> Orange -> Red + thickness_overlay : bool, optional + Whether to overlay thickness values on the mesh, by default True. + show_contours : bool, optional + Whether to show the contours, by default False. + show_grid : bool, optional + Whether to show the grid, by default False. + color_range : tuple[float, float], optional + Fixed range (min, max) for the colorbar, by default None. + show_mesh_edges : bool, optional + Whether to show the mesh edges, by default False. + legend : str, optional + Legend text for the colorbar, by default "". + threshold : tuple[float, float], optional + Values between these thresholds will be shown in grey, by default None. + + Notes + ----- + The plot can be saved to an HTML file or displayed in a web browser. """ assert self.v is not None and self.t is not None, "Mesh has not been created yet" @@ -430,32 +448,50 @@ def plot_mesh( fig.write_html(temp_path) webbrowser.open("file://" + temp_path) - def get_contour_edge_lengths(self, contour_idx): + def get_contour_edge_lengths(self, contour_idx: int) -> np.ndarray: """Get the lengths of the edges of a contour. - Args: - contour_idx (int): Index of the contour to get the edge lengths for. + Parameters + ---------- + contour_idx : int + Index of the contour to get the edge lengths for. + + Returns + ------- + np.ndarray + Array of edge lengths for the contour. - Returns: - numpy.ndarray: Array of edge lengths for the contour. + Notes + ----- + Edge lengths are calculated as Euclidean distances between consecutive points + in the contour. """ edges = np.diff(self.contours[contour_idx], axis=0) return np.sqrt(np.sum(edges**2, axis=1)) @staticmethod - def make_triangles_between_contours(contour1, contour2): - """Creates a triangular mesh between two contours using a robust method. + def make_triangles_between_contours(contour1: np.ndarray, contour2: np.ndarray) -> np.ndarray: + """Create a triangular mesh between two contours using a robust method. 
- This method creates triangles that connect two contours by matching points between them. - It starts from the closest point on contour2 to the first point of contour1 and creates - triangles by connecting corresponding points. - - Args: - contour1 (numpy.ndarray): First contour points of shape (N, 2). - contour2 (numpy.ndarray): Second contour points of shape (M, 2). + Parameters + ---------- + contour1 : np.ndarray + First contour points of shape (N, 2). + contour2 : np.ndarray + Second contour points of shape (M, 2). - Returns: - numpy.ndarray: Array of triangle indices of shape (K, 3) where K is the number of triangles. + Returns + ------- + np.ndarray + Array of triangle indices of shape (K, 3) where K is the number of triangles. + + Notes + ----- + The function: + 1. Finds closest point on contour2 to first point of contour1 + 2. Creates triangles by connecting corresponding points + 3. Handles contours with different numbers of points + 4. Creates two triangles to form a quad between each pair of points """ start_idx_c1 = 0 # get closest point on contour2 to contour1[0] @@ -480,8 +516,40 @@ def make_triangles_between_contours(contour1, contour2): return np.array(triangles) - def _create_levelpaths(self, contour_idx, points, trias, num_points=None): - # # compute poisson + def _create_levelpaths( + self, + contour_idx: int, + points: np.ndarray, + trias: np.ndarray, + num_points: int | None = None + ) -> tuple[list[np.ndarray], list[float]]: + """Create level paths for thickness measurements. + + Parameters + ---------- + contour_idx : int + Index of the contour to process + points : np.ndarray + Array of shape (N, 2) containing mesh points + trias : np.ndarray + Array of shape (M, 3) containing triangle indices + num_points : int or None, optional + Number of points to sample along the midline, by default None + + Returns + ------- + tuple[list[np.ndarray], list[float]] + - levelpaths : List of arrays containing level path coordinates + - thickness_values : List of thickness values for each level path + + Notes + ----- + The function: + 1. Creates a triangular mesh from the points + 2. Finds boundary points and endpoints + 3. Solves Poisson equation for level sets + 4. Extracts level paths and interpolates thickness values + """ with HiddenPrints(): cc_tria = lapy.TriaMesh(points, trias) # extract boundary curve @@ -541,7 +609,38 @@ def _create_levelpaths(self, contour_idx, points, trias, num_points=None): return levelpaths, thickness_values - def _create_cap(self, points, trias, contour_idx): + def _create_cap( + self, + points: np.ndarray, + trias: np.ndarray, + contour_idx: int + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Create a cap mesh for one end of the corpus callosum. + + Parameters + ---------- + points : np.ndarray + Array of shape (N, 2) containing mesh points + trias : np.ndarray + Array of shape (M, 3) containing triangle indices + contour_idx : int + Index of the contour to create cap for + + Returns + ------- + tuple[np.ndarray, np.ndarray, np.ndarray] + - level_vertices : Array of vertices for the cap mesh + - level_faces : Array of face indices for the cap mesh + - level_colors : Array of thickness values for each vertex + + Notes + ----- + The function: + 1. Creates level paths using _create_levelpaths + 2. Resamples level paths to fixed number of points + 3. Creates triangles between consecutive level paths + 4. 
Smooths thickness values for visualization + """ levelpaths, thickness_values = self._create_levelpaths(contour_idx, points, trias) # Create mesh from level paths @@ -601,17 +700,34 @@ def _create_cap(self, points, trias, contour_idx): return level_vertices, level_faces, level_colors - def create_mesh(self, lr_center: float = 0, closed: bool = False, smooth: int = 0): - """Creates a surface mesh by triangulating between consecutive contours. + def create_mesh(self, lr_center: float = 0, closed: bool = False, smooth: int = 0) -> None: + """Create a surface mesh by triangulating between consecutive contours. - This method constructs a 3D mesh from the stored contours by creating triangles between - adjacent slices. It can optionally create a closed mesh by adding caps at the ends and - apply smoothing. - - Args: - lr_center (float, optional): Center position in the left-right axis. Defaults to 0. - closed (bool, optional): Whether to create a closed mesh by adding caps. Defaults to False. - smooth (int, optional): Number of smoothing iterations to apply. Defaults to 0. + Parameters + ---------- + lr_center : float, optional + Center position in the left-right axis, by default 0. + closed : bool, optional + Whether to create a closed mesh by adding caps, by default False. + smooth : int, optional + Number of smoothing iterations to apply, by default 0. + + Raises + ------ + Warning + If no valid contours are found. + + Notes + ----- + The function: + 1. Filters out None contours. + 2. Calculates z-coordinates for each slice. + 3. Creates triangles between adjacent contours. + 4. Optionally: + - Creates caps at both ends. + - Applies smoothing. + - Colors caps based on thickness values. + """ # Filter out None contours and get their indices valid_contours = [(i, c) for i, c in enumerate(self.contours) if c is not None] @@ -683,9 +799,22 @@ def create_mesh(self, lr_center: float = 0, closed: bool = False, smooth: int = [self.mesh_vertex_colors, left_side_colors, right_side_colors], ) - def fill_thickness_values(self): - """ - Interpolate missing thickness values on the contours by weighted average of nearest known thickness values. + def fill_thickness_values(self) -> None: + """Interpolate missing thickness values using weighted averaging. + + Notes + ----- + The function: + 1. Processes each contour with missing thickness values. + 2. For each missing value: + - Finds two closest points with known thickness. + - Calculates distances along contour. + - Computes weighted average based on inverse distance. + 3. Updates thickness values in place. + + The weights are calculated as inverse distances to ensure closer + points have more influence on the interpolated value. + """ # For each contour with missing thickness values @@ -726,27 +855,45 @@ def fill_thickness_values(self): self.thickness_values[i] = thickness - def smooth_thickness_values(self, iterations: int = 1): - """ - Smooth the thickness values using a Gaussian filter - """ + def smooth_thickness_values(self, iterations: int = 1) -> None: + """Smooth the thickness values using a Gaussian filter. + + Parameters + ---------- + iterations : int, optional + Number of smoothing iterations, by default 1. + Notes + ----- + Applies Gaussian smoothing with sigma=5 to thickness values + for each slice that has measurements. 
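[Review note] The inverse-distance weighting described in `fill_thickness_values`' Notes, sketched with index distance along the contour standing in for arc length (a simplification; the method itself computes distances along the contour and needs at least two known values):

```python
import numpy as np

def fill_missing_inverse_distance(values) -> np.ndarray:
    """Sketch: fill NaNs from the two nearest known values, weighted by 1/d."""
    vals = np.asarray(values, dtype=float).copy()
    known = np.flatnonzero(~np.isnan(vals))      # indices with measurements
    for i in np.flatnonzero(np.isnan(vals)):
        nearest = known[np.argsort(np.abs(known - i))[:2]]
        w = 1.0 / np.abs(nearest - i)            # closer points weigh more
        vals[i] = np.sum(w * vals[nearest]) / np.sum(w)
    return vals
```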
+ """ for i in range(len(self.thickness_values)): if self.thickness_values[i] is not None: self.thickness_values[i] = gaussian_filter1d(self.thickness_values[i], sigma=5) - def plot_contour(self, slice_idx: int, output_path: str): + def plot_contour(self, slice_idx: int, output_path: str) -> None: """Plot a single contour with thickness values. - Creates a 2D visualization of a specific contour slice with points colored according - to their thickness values. The plot is saved to the specified output path. - - Args: - slice_idx (int): Index of the slice to plot. - output_path (str): Path where to save the plot. - - Raises: - ValueError: If the contour for the specified slice is not set. + Parameters + ---------- + slice_idx : int + Index of the slice to plot. + output_path : str + Path where to save the plot. + + Raises + ------ + ValueError + If the contour for the specified slice is not set. + + Notes + ----- + Creates a 2D visualization with: + - Points colored by thickness values. + - Gray points for missing thickness values. + - Connected contour line. + - Grid, labels, and legend. """ self.__make_parent_folder(output_path) @@ -784,43 +931,57 @@ def plot_contour(self, slice_idx: int, output_path: str): plt.tight_layout() plt.savefig(output_path, dpi=300) - def smooth_contour(self, contour_idx, window_size=5): - """ - Smooth a contour using a moving average filter. + def smooth_contour(self, contour_idx: int, window_size: int = 5) -> None: + """Smooth a contour using a moving average filter. Parameters ---------- - contour : tuple of arrays - The contour coordinates (x, y). - window_size : int - Size of the smoothing window. - - Returns - ------- - tuple of arrays - The smoothed contour coordinates (x, y). + contour_idx : int + Index of the contour to smooth. + window_size : int, optional + Size of the smoothing window, by default 5. + + Notes + ----- + Uses smooth_contour from cc_endpoint_heuristic module to: + 1. Extract x and y coordinates. + 2. Apply moving average smoothing. + 3. Update contour with smoothed coordinates. """ x, y = self.contours[contour_idx].T - x, y = smooth_contour(x, y, window_size) - self.contours[contour_idx] = np.array([x, y]).T - def plot_cc_contour_with_levelsets(self, contour_idx=0, levelpaths=None, title=None, save_path=None, colorbar=True): + def plot_cc_contour_with_levelsets( + self, + contour_idx: int = 0, + levelpaths: list | None = None, + title: str | None = None, + save_path: str | None = None, + colorbar: bool = True, + ) -> matplotlib.figure.Figure: """Plot a contour with levelset visualization. Creates a visualization of a contour with interpolated levelsets, useful for analyzing the thickness distribution across the corpus callosum. - Args: - contour_idx (int, optional): Index of the contour to plot. Defaults to 0. - levelpaths (list, optional): List of levelset paths. If None, uses stored levelpaths. - title (str, optional): Title for the plot. Defaults to None. - save_path (str, optional): Path to save the plot. If None, displays interactively. - colorbar (bool, optional): Whether to show the colorbar. Defaults to True. + Parameters + ---------- + contour_idx : int, optional + Index of the contour to plot, by default 0. + levelpaths : list or None, optional + List of levelset paths. If None, uses stored levelpaths, by default None. + title : str or None, optional + Title for the plot, by default None. + save_path : str or None, optional + Path to save the plot. If None, displays interactively, by default None. 
+ colorbar : bool, optional + Whether to show the colorbar, by default True. - Returns: - matplotlib.figure.Figure: The created figure object. + Returns + ------- + matplotlib.figure.Figure + The created figure object. """ plot_values = np.array(self.thickness_values[contour_idx][~np.isnan(self.thickness_values[contour_idx])][:100])[ @@ -999,13 +1160,25 @@ def plot_cc_contour_with_levelsets(self, contour_idx=0, levelpaths=None, title=N plt.show() return fig - def set_mesh(self, vertices, faces, thickness_values=None): + def set_mesh(self, + vertices: list | np.ndarray, + faces: list | np.ndarray, + thickness_values: list | np.ndarray | None = None) -> None: """Set the mesh vertices, faces, and optional thickness values. - Args: - vertices (list or numpy.ndarray): List of vertex coordinates or array of shape (N, 3). - faces (list or numpy.ndarray): List of face indices or array of shape (M, 3). - thickness_values (list or numpy.ndarray, optional): Thickness values for each vertex. + Parameters + ---------- + vertices : list or numpy.ndarray + List of vertex coordinates or array of shape (N, 3). + faces : list or numpy.ndarray + List of face indices or array of shape (M, 3). + thickness_values : list or numpy.ndarray, optional + Thickness values for each vertex. + + Returns + ------- + None + The function does not return anything. """ # Handle case when there are no faces (single contour) if not faces: @@ -1028,9 +1201,23 @@ def set_mesh(self, vertices, faces, thickness_values=None): self.mesh_vertex_colors = np.array([]) @staticmethod - def __create_cc_viewmat(): - """ - Create the view matrix for a nice view of the corpus callosum. + def __create_cc_viewmat() -> pyrr.Matrix44: + """Create the view matrix for a nice view of the corpus callosum. + + Returns + ------- + pyrr.Matrix44 + 4x4 view matrix that provides a standard view of the corpus callosum + + Notes + ----- + The function: + 1. Creates a base view matrix looking from the left with top up + 2. Applies a series of rotations: + - -10 degrees around x-axis + - 35 degrees around y-axis + - -8 degrees around z-axis + 3. Adds a small translation for better centering """ viewLeft = np.array([[0, 0, -1, 0], [-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) # left w top up // right transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) @@ -1050,23 +1237,40 @@ def __create_cc_viewmat(): return viewmat - def snap_cc_picture(self, output_path: str, fssurf_file: str | None = None, overlay_file: str | None = None): + def snap_cc_picture( + self, + output_path: str, + fssurf_file: str | None = None, + overlay_file: str | None = None + ) -> None: """Snap a picture of the corpus callosum mesh. - Takes a snapshot of the mesh from a predefined viewpoint, with optional thickness - overlay. The image is saved to the specified output path. - - Args: - output_path (str): - Path where to save the snapshot image. - fssurf_file (str | None): Path to a FreeSurfer surface file to use for the snapshot. If None, - the mesh is saved to a temporary file. Defaults to None. - overlay_file (str | None): Path to a FreeSurfer overlay file to use for the snapshot. If None, - the mesh is saved to a temporary file. Defaults to None. + Parameters + ---------- + output_path : str + Path where to save the snapshot image. + fssurf_file : str or None, optional + Path to a FreeSurfer surface file to use for the snapshot. + If None, the mesh is saved to a temporary file, by default None. 
+ overlay_file : str or None, optional + Path to a FreeSurfer overlay file to use for the snapshot. + If None, the mesh is saved to a temporary file, by default None. + + Raises + ------ + Warning + If the mesh has no faces and cannot create a snapshot. + + Notes + ----- + The function: + 1. Creates temporary files for mesh and overlay data if needed. + 2. Uses whippersnappy to create a snapshot with: + - Custom view matrix for standard orientation. + - Ambient lighting and colorbar settings. + - Thickness overlay if available. + 3. Cleans up temporary files after use. - Note: - This method uses a temporary file to store the mesh and overlay data during - the snapshot process. """ self.__make_parent_folder(output_path) # Skip snapshot if there are no faces @@ -1115,30 +1319,39 @@ def snap_cc_picture(self, output_path: str, fssurf_file: str | None = None, over temp_file.close() overlay_file.close() - def smooth_(self, iterations: int = 1): + def smooth_(self, iterations: int = 1) -> None: """Smooth the mesh while preserving the z-coordinates. - This method applies Laplacian smoothing to the mesh vertices while keeping - the z-coordinates unchanged to maintain the slice structure. - - Args: - iterations (int, optional): Number of smoothing iterations. Defaults to 1. + Parameters + ---------- + iterations : int, optional + Number of smoothing iterations, by default 1. + + Notes + ----- + The function: + 1. Stores original z-coordinates. + 2. Applies Laplacian smoothing to x and y coordinates. + 3. Restores original z-coordinates to maintain slice structure. """ z_values = self.v[:, 2] super().smooth_(iterations) self.v[:, 2] = z_values - def save_contours(self, output_path: str): + def save_contours(self, output_path: str) -> None: """Save the contours to a CSV file. - Saves all contours and their associated endpoint indices to a CSV file. - The file format is: - slice_idx,x,y - where each point of each contour gets its own row, with special lines indicating - the start of new contours and their endpoint indices. - - Args: - output_path (str): Path where to save the CSV file. + Parameters + ---------- + output_path : str + Path where to save the CSV file. + + Notes + ----- + The function saves contours in CSV format with: + - Header: slice_idx,x,y. + - Special lines indicating new contours with endpoint indices. + - Each point gets its own row with slice index and coordinates. """ logger.info(f"Saving contours to CSV file: {output_path}") with open(output_path, "w") as f: @@ -1154,18 +1367,26 @@ def save_contours(self, output_path: str): for point in contour: f.write(f"{slice_idx},{point[0]},{point[1]}\n") - def load_contours(self, input_path: str): + def load_contours(self, input_path: str) -> None: """Load contours from a CSV file. - Loads contours and their associated endpoint indices from a CSV file. - The file format should match that produced by save_contours: - slice_idx,x,y with special lines for endpoint indices. - - Args: - input_path (str): Path to the CSV file containing the contours. - - Note: - This method will reset any existing contours and endpoint indices. + Parameters + ---------- + input_path : str + Path to the CSV file containing the contours. + + Raises + ------ + ValueError + If the file format doesn't match expected structure. + + Notes + ----- + The function: + 1. Reads CSV file with format matching save_contours output. + 2. Processes special lines for endpoint indices. + 3. Reconstructs contours and endpoint indices for each slice. + 4. 
Converts lists to fixed-size arrays with None padding. """ current_points = [] self.contours = [] @@ -1202,15 +1423,20 @@ def load_contours(self, input_path: str): self.contours = self.contours + [None] * (max_slices - len(self.contours)) self.start_end_idx = self.start_end_idx + [None] * (max_slices - len(self.start_end_idx)) - def save_thickness_values(self, output_path: str): + def save_thickness_values(self, output_path: str) -> None: """Save thickness values to a CSV file. - Saves all thickness values to a CSV file in the format: - slice_idx,thickness - where each thickness value gets its own row. - - Args: - output_path (str): Path where to save the CSV file. + Parameters + ---------- + output_path : str + Path where to save the CSV file. + + Notes + ----- + The function saves thickness values in CSV format with: + - Header: slice_idx,thickness. + - Each thickness value gets its own row with slice index. + - Skips slices with no thickness values. """ logger.info(f"Saving thickness data to CSV file: {output_path}") with open(output_path, "w") as f: @@ -1222,24 +1448,36 @@ def save_thickness_values(self, output_path: str): for value in thickness: f.write(f"{slice_idx},{value}\n") - def load_thickness_values(self, input_path: str, original_thickness_vertices_path: str | None = None): + def load_thickness_values( + self, + input_path: str, + original_thickness_vertices_path: str | None = None + ) -> None: """Load thickness values from a CSV file. - Loads thickness values from a CSV file and optionally associates them with specific - vertices using a measurement points file. - - Args: - input_path (str): - Path to the CSV file containing thickness values. - original_thickness_vertices_path (str, optional): - Path to a file containing the - indices of vertices where thickness was measured. If None, assumes thickness - values correspond to all vertices in order. - - Raises: - ValueError: - If the number of thickness values doesn't match the number of - measurement points, or if the number of slices is inconsistent. + Parameters + ---------- + input_path : str + Path to the CSV file containing thickness values. + original_thickness_vertices_path : str or None, optional + Path to a file containing the indices of vertices where thickness + was measured, by default None. + + Raises + ------ + ValueError + If number of thickness values doesn't match measurement points + or if number of slices is inconsistent. + + Notes + ----- + The function: + 1. Reads thickness values from CSV file. + 2. Groups values by slice index. + 3. Optionally associates values with specific vertices. + 4. Handles both full contour and profile measurements. + + """ data = np.loadtxt(input_path, delimiter=",", skiprows=1) slice_indices = data[:, 0].astype(int) @@ -1319,17 +1557,43 @@ def load_thickness_values(self, input_path: str, original_thickness_vertices_pat self.thickness_values = new_thickness_values @staticmethod - def __make_parent_folder(filename: str): - """Make the parent folder of the given filename. + def __make_parent_folder(filename: str) -> None: + """Create the parent folder for a file if it doesn't exist. + + Parameters + ---------- + filename : str + Path to the file whose parent folder should be created. + + Notes + ----- + Creates parent directory with parents=False to avoid creating + multiple levels of directories unintentionally. 
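[Review note] The CSV layout used by `save_thickness_values`/`load_thickness_values` above is easy to round-trip outside the class as well; a sketch of the read side (the file name is hypothetical):

```python
import numpy as np

# format written above: header line, then one "slice_idx,thickness" row per value
data = np.loadtxt("cc_thickness.csv", delimiter=",", skiprows=1)
slice_idx = data[:, 0].astype(int)
thickness_by_slice = {s: data[slice_idx == s, 1] for s in np.unique(slice_idx)}
```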
""" output_folder = Path(filename).parent output_folder.mkdir(parents=False, exist_ok=True) - def to_fs_coordinates(self, vox2ras_tkr: np.ndarray, vox_size: tuple[int, int, int]): + def to_fs_coordinates( + self, + vox2ras_tkr: np.ndarray, + vox_size: tuple[float, float, float] + ) -> None: """Convert mesh coordinates to FreeSurfer coordinate system. - Transforms the mesh vertices from the original coordinate system to the - FreeSurfer coordinate system by reordering axes and applying appropriate offsets. + Parameters + ---------- + vox2ras_tkr : np.ndarray + 4x4 voxel to RAS tkr-space transformation matrix. + vox_size : tuple[float, float, float] + Voxel size in millimeters (x, y, z). + + Notes + ----- + The function: + 1. Converts coordinates from original to LSA orientation. + 2. Converts to voxel coordinates using voxel size. + 3. Centers LR coordinates and flips SI coordinates. + 4. Applies vox2ras_tkr transformation to get final coordinates. """ # to voxel coordinates @@ -1362,38 +1626,60 @@ def to_fs_coordinates(self, vox2ras_tkr: np.ndarray, vox_size: tuple[int, int, i - def write_fssurf(self, filename): + def write_fssurf(self, filename: str) -> None: """Write the mesh to a FreeSurfer surface file. - Args: - filename (str): Path where to save the FreeSurfer surface file. + Parameters + ---------- + filename : str + Path where to save the FreeSurfer surface file. + + Returns + ------- + None + Returns the result of the parent class's write_fssurf method. - Returns: - The result of the parent class's write_fssurf method. + Notes + ----- + Creates parent directory if needed before writing the file. """ self.__make_parent_folder(filename) return super().write_fssurf(filename) - def write_overlay(self, filename): + def write_overlay(self, filename: str) -> None: """Write the thickness values as a FreeSurfer overlay file. - Args: - filename (str): Path where to save the overlay file. + Parameters + ---------- + filename : str + Path where to save the overlay file. + + Returns + ------- + None + Returns the result of writing the morph data using nibabel. - Returns: - The result of writing the morph data using nibabel. + Notes + ----- + Creates parent directory if needed before writing the file. """ self.__make_parent_folder(filename) return nib.freesurfer.write_morph_data(filename, self.mesh_vertex_colors) - def save_thickness_measurement_points(self, filename): + def save_thickness_measurement_points(self, filename: str) -> None: """Write the thickness measurement points to a CSV file. - Saves the indices of vertices where thickness was measured for each slice - in CSV format: slice_idx,vertex_idx - - Args: - filename (str): Path where to save the CSV file. + Parameters + ---------- + filename : str + Path where to save the CSV file. + + Notes + ----- + The function saves measurement points in CSV format with: + - Header: slice_idx,vertex_idx. + - Each measurement point gets its own row. + - Skips slices with no measurement points. """ self.__make_parent_folder(filename) logger.info(f"Saving thickness measurement points to CSV file: {filename}") @@ -1405,15 +1691,27 @@ def save_thickness_measurement_points(self, filename): f.write(f"{slice_idx},{vertex_idx}\n") @staticmethod - def _load_thickness_measurement_points(filename): + def _load_thickness_measurement_points(filename: str) -> list[np.ndarray | None]: """Load thickness measurement points from a CSV file. - Args: - filename (str): Path to the CSV file containing measurement points. 
+ Parameters + ---------- + filename : str + Path to the CSV file containing measurement points. - Returns: - list: List of arrays containing vertex indices for each slice where - thickness was measured. + Returns + ------- + list[np.ndarray | None] + List of arrays containing vertex indices for each slice where + thickness was measured. None for slices without measurements. + + Notes + ----- + The function: + 1. Reads CSV file with format: slice_idx,vertex_idx + 2. Groups vertex indices by slice index + 3. Creates a list with length matching max slice index + 4. Fills list with vertex indices arrays or None for missing slices """ data = np.loadtxt(filename, delimiter=",", skiprows=1) slice_indices = data[:, 0].astype(int) diff --git a/CorpusCallosum/shape/cc_metrics.py b/CorpusCallosum/shape/cc_metrics.py index f4921363..7bd62886 100644 --- a/CorpusCallosum/shape/cc_metrics.py +++ b/CorpusCallosum/shape/cc_metrics.py @@ -15,15 +15,18 @@ import numpy as np -def calculate_cc_index(cc_contour): - """ - Calculate CC index based on three perpendicular measurements. - - Args: - cc_contour: 2xN array of contour points in ACPC space - - Returns: - float: Sum of thicknesses at three measurement points +def calculate_cc_index(cc_contour: np.ndarray) -> float: + """Calculate CC index based on three perpendicular measurements. + + Parameters + ---------- + cc_contour : np.ndarray + Array of shape (2, N) containing contour points in ACPC space. + + Returns + ------- + float + Sum of thicknesses at three measurement points divided by AP length. """ # Get anterior and posterior points anterior_idx = np.argmin(cc_contour[0]) # Leftmost point @@ -40,7 +43,21 @@ def calculate_cc_index(cc_contour): # Get perpendicular direction # Get intersection points with contour for each measurement line - def get_intersections(start_point, direction): + def get_intersections(start_point: np.ndarray, direction: np.ndarray) -> np.ndarray: + """Find intersection points between a line and the contour. + + Parameters + ---------- + start_point : np.ndarray + Starting point of the line, shape (2,). + direction : np.ndarray + Direction vector of the line, shape (2,). + + Returns + ------- + np.ndarray + Array of shape (N, 2) containing intersection points. + """ # Get all points above and below the line points = cc_contour.T - start_point[None, :] dots = np.dot(points, direction) diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/cc_postprocessing.py index b7aedb9d..e542e700 100644 --- a/CorpusCallosum/shape/cc_postprocessing.py +++ b/CorpusCallosum/shape/cc_postprocessing.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import multiprocessing from pathlib import Path import numpy as np @@ -43,21 +44,34 @@ LIA_ORIENTATION[2,1] = -1 -def create_visualization(subdivision_method, result, midslices_data, output_image_path, - ac_coords, pc_coords, vox_size, title_suffix=""): - """Helper function to create visualization plots based on subdivision method. 
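`create_visualization` hands the plotting call to `run_in_background` and returns the resulting `multiprocessing.Process`, so figure rendering does not block slice processing. The helper lives elsewhere in the repository; a minimal stand-in consistent with this usage might be:

```python
import multiprocessing

def run_in_background(func, **kwargs) -> multiprocessing.Process:
    # Start func(**kwargs) in a child process; the caller collects the
    # returned handles (IO_processes) and joins them before exiting.
    process = multiprocessing.Process(target=func, kwargs=kwargs)
    process.start()
    return process
```

On platforms that spawn rather than fork new processes, such calls need to run under an `if __name__ == "__main__":` guard.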
- - Args: - subdivision_method: The subdivision method being used - result: Dictionary containing processing results with split_contours and split_contours_hofer_frahm - midslices_data: Slice data for visualization - output_subdir: Directory to save visualization - ac_coords: AC coordinates - pc_coords: PC coordinates - title_suffix: Additional text to append to the title - - Returns: - Process object for background execution +def create_visualization(subdivision_method: str, result: dict, midslices_data: np.ndarray, + output_image_path: str | Path, ac_coords: np.ndarray, + pc_coords: np.ndarray, vox_size: float, title_suffix: str = "") -> multiprocessing.Process: + """Create visualization plots based on subdivision method. + + Parameters + ---------- + subdivision_method : str + The subdivision method being used. + result : dict + Dictionary containing processing results with split_contours and split_contours_hofer_frahm. + midslices_data : np.ndarray + Slice data for visualization. + output_image_path : str or Path + Path to save visualization. + ac_coords : np.ndarray + AC coordinates. + pc_coords : np.ndarray + PC coordinates. + vox_size : float + Voxel size in mm. + title_suffix : str, optional + Additional text to append to the title, by default "". + + Returns + ------- + multiprocessing.Process + Process object for background execution. """ title = f'CC Subsegmentation by {subdivision_method} {title_suffix}' @@ -84,78 +98,100 @@ def create_visualization(subdivision_method, result, midslices_data, output_imag -def create_slice_affine(temp_seg_affine, slice_idx, fsaverage_middle): +def create_slice_affine(temp_seg_affine: np.ndarray, slice_idx: int, fsaverage_middle: int) -> np.ndarray: """Create slice-specific affine transformation matrix. - - Adjusts the input affine transformation matrix for a specific slice by updating - the translation component based on the slice index and fsaverage middle reference. - - Args: - temp_seg_affine (np.ndarray): Base 4x4 affine transformation matrix - slice_idx (int): Index of the slice to transform - fsaverage_middle (int): Reference middle slice index in fsaverage space - - Returns: - np.ndarray: Modified 4x4 affine transformation matrix for the specific slice + + Parameters + ---------- + temp_seg_affine : np.ndarray + Base 4x4 affine transformation matrix. + slice_idx : int + Index of the slice to transform. + fsaverage_middle : int + Reference middle slice index in fsaverage space. + + Returns + ------- + np.ndarray + Modified 4x4 affine transformation matrix for the specific slice. """ slice_affine = temp_seg_affine.copy() slice_affine[0, 3] = -fsaverage_middle + slice_idx return slice_affine -def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thickness_points, subdivisions, - subdivision_method, contour_smoothing, vox_size): +def process_slice( + segmentation: np.ndarray, + slice_idx: int, + ac_coords: np.ndarray, + pc_coords: np.ndarray, + affine: np.ndarray, + num_thickness_points: int, + subdivisions: list[float], + subdivision_method: str, + contour_smoothing: float, + vox_size: float +) -> dict | None: """Process a single slice for corpus callosum measurements. 
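`get_intersections` finds where a measurement line crosses the contour by checking which points fall on either side of it. A toy version of the sign-change test on signed distances (the wrap-around edge back to the first vertex is omitted for brevity; all names are illustrative):

```python
import numpy as np

# Unit-square contour (2, N) and a vertical measurement line through x = 0.5.
cc_contour = np.array([[0.0, 1.0, 1.0, 0.0],
                       [0.0, 0.0, 1.0, 1.0]])
start_point = np.array([0.5, 0.0])
direction = np.array([1.0, 0.0])  # normal direction of the measurement line

signed = (cc_contour.T - start_point[None, :]) @ direction
crossing_edges = np.nonzero(np.diff(np.sign(signed)) != 0)[0]
print(crossing_edges)  # [0 2]: the edges starting at vertices 0 and 2 cross the line
```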
- - Performs detailed analysis of a corpus callosum slice, including: - - Contour extraction and endpoint detection - - Thickness profile calculation - - Area and perimeter measurements - - Shape-based metrics (circularity, CC index) - - Subdivision into anatomical regions - - Args: - segmentation (np.ndarray): - 3D segmentation array - slice_idx (int): - Index of the slice to process - ac_coords (np.ndarray): - Anterior commissure coordinates - pc_coords (np.ndarray): - Posterior commissure coordinates - affine (np.ndarray): - 4x4 affine transformation matrix - num_thickness_points (int): - Number of points for thickness estimation - subdivisions (list[float]): - List of fractions for anatomical subdivisions - subdivision_method (str): - Method for contour subdivision ('shape', 'vertical', - 'angular', or 'eigenvector') - contour_smoothing (float): - Gaussian sigma for contour smoothing - Returns: - slice_data (dict | None): - Dictionary containing measurements if successful, including: - - - cc_index: Corpus callosum shape index - - circularity: Shape circularity measure - - areas: Areas of subdivided regions - - midline_length: Length along the midline - - thickness: Array of thickness measurements - - curvature: Array of curvature measurements - - thickness_profile: Thickness measurements along the contour - - total_area: Total area of the CC - - total_perimeter: Total perimeter length - - split_contours: Subdivided contour segments - - split_contours_hofer_frahm: Alternative subdivision (if applicable) - - midline_equidistant: Equidistant points along midline - - levelpaths: Paths for thickness measurements - - thickness_measurement_points: Points where thickness was measured - - slice_index: Index of the processed slice - - Returns None if no CC is found in the slice. - + + Parameters + ---------- + segmentation : np.ndarray + 3D segmentation array. + slice_idx : int + Index of the slice to process. + ac_coords : np.ndarray + Anterior commissure coordinates. + pc_coords : np.ndarray + Posterior commissure coordinates. + affine : np.ndarray + 4x4 affine transformation matrix. + num_thickness_points : int + Number of points for thickness estimation. + subdivisions : list[float] + List of fractions for anatomical subdivisions. + subdivision_method : str + Method for contour subdivision ('shape', 'vertical', 'angular', or 'eigenvector'). + contour_smoothing : float + Gaussian sigma for contour smoothing. + vox_size : float + Voxel size in millimeters. + + Returns + ------- + dict | None + Dictionary containing measurements if successful, including: + - cc_index : float - Corpus callosum shape index. + - circularity : float - Shape circularity measure. + - areas : np.ndarray - Areas of subdivided regions. + - midline_length : float - Length along the midline. + - thickness : np.ndarray - Array of thickness measurements. + - curvature : np.ndarray - Array of curvature measurements. + - thickness_profile : list[float] - Thickness measurements along the contour. + - total_area : float - Total area of the CC. + - total_perimeter : float - Total perimeter length. + - split_contours : list[np.ndarray] - Subdivided contour segments. + - split_contours_hofer_frahm : list[np.ndarray] - Alternative subdivision (if applicable). + - midline_equidistant : np.ndarray - Equidistant points along midline. + - levelpaths : list[np.ndarray] - Paths for thickness measurements. + - thickness_measurement_points : np.ndarray - Points where thickness was measured. 
+ - slice_index : int - Index of the processed slice. + Returns None if no CC is found in the slice. + + Raises + ------ + ValueError + If no CC is found in the specified slice. + + Notes + ----- + The function performs the following steps: + 1. Extracts CC contour and identifies endpoints. + 2. Converts coordinates to RAS space. + 3. Calculates thickness profile using Laplace equation. + 4. Computes shape metrics and subdivisions. + 5. Generates visualization data. + """ cc_mask_slice = segmentation[slice_idx] == 192 @@ -241,37 +277,82 @@ def process_slice(segmentation, slice_idx, ac_coords, pc_coords, affine, num_thi }, contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx -def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac_coords, pc_coords, - num_thickness_points, subdivisions, subdivision_method, contour_smoothing, - debug_image_path=None, one_debug_image=False, - thickness_image_path=None, vox_size=None, vox2ras_tkr=None, - save_template=None, surf_file_path=None, overlay_file_path=None, cc_html_path=None, - vtk_file_path=None, verbose=False): +def process_slices( + segmentation: np.ndarray, + slice_selection: str, + temp_seg_affine: np.ndarray, + midslices: np.ndarray, + ac_coords: np.ndarray, + pc_coords: np.ndarray, + num_thickness_points: int, + subdivisions: list[float], + subdivision_method: str, + contour_smoothing: float, + debug_image_path: str | None = None, + one_debug_image: bool = False, + thickness_image_path: str | None = None, + vox_size: tuple[float, float, float] | None = None, + vox2ras_tkr: np.ndarray | None = None, + save_template: str | Path | None = None, + surf_file_path: str | None = None, + overlay_file_path: str | None = None, + cc_html_path: str | None = None, + vtk_file_path: str | None = None, + verbose: bool = False +) -> tuple[list, list]: """Process corpus callosum slices based on selection mode. - - Handles the processing of either a single middle slice, all slices, or a specific slice, - including affine transformations and measurements for each slice. - - Args: - segmentation (np.ndarray): 3D segmentation array - slice_selection (str): Which slices to process ('middle', 'all', or slice number) - temp_seg_affine (np.ndarray): Base affine transformation matrix - midslices (np.ndarray): Array of mid-sagittal slices - ac_coords (np.ndarray): Anterior commissure coordinates - pc_coords (np.ndarray): Posterior commissure coordinates - num_thickness_points (int): Number of points for thickness estimation - subdivisions (list[float]): List of fractions for anatomical subdivisions - subdivision_method (str): Method for contour subdivision - contour_smoothing (float): Gaussian sigma for contour smoothing - debug_image_path (str, optional): Path for debug visualization image - verbose (bool): Whether to print progress information - save_template (str | Path | None): Directory path where to save template files, or None to skip saving - - Returns: - tuple: Contains: - - - list: List of slice processing results - - list: List of background IO processes + + Parameters + ---------- + segmentation : np.ndarray + 3D segmentation array. + slice_selection : str + Which slices to process ('middle', 'all', or slice number). + temp_seg_affine : np.ndarray + Base affine transformation matrix. + midslices : np.ndarray + Array of mid-sagittal slices. + ac_coords : np.ndarray + Anterior commissure coordinates. + pc_coords : np.ndarray + Posterior commissure coordinates. 
+ num_thickness_points : int + Number of points for thickness estimation. + subdivisions : list[float] + List of fractions for anatomical subdivisions. + subdivision_method : str + Method for contour subdivision. + contour_smoothing : float + Gaussian sigma for contour smoothing. + debug_image_path : str or None, optional + Path for debug visualization image, by default None. + one_debug_image : bool, optional + Whether to save only one debug image, by default False. + thickness_image_path : str or None, optional + Path for thickness visualization image, by default None. + vox_size : tuple[float, float, float] or None, optional + Voxel size in millimeters (x, y, z), by default None. + vox2ras_tkr : np.ndarray or None, optional + Voxel to RAS tkr-space transformation matrix, by default None. + save_template : str or Path or None, optional + Directory path where to save template files, by default None. + surf_file_path : str or None, optional + Path to save surface file, by default None. + overlay_file_path : str or None, optional + Path to save overlay file, by default None. + cc_html_path : str or None, optional + Path to save HTML visualization, by default None. + vtk_file_path : str or None, optional + Path to save VTK file, by default None. + verbose : bool, optional + Whether to print progress information, by default False. + + Returns + ------- + list + List of slice processing results. + list + List of background IO processes. """ slice_results = [] IO_processes = [] @@ -424,17 +505,25 @@ def process_slices(segmentation, slice_selection, temp_seg_affine, midslices, ac -def vectorized_line_test(coords_x, coords_y, line_start, line_end): +def vectorized_line_test(coords_x: np.ndarray, coords_y: np.ndarray, + line_start: np.ndarray, line_end: np.ndarray) -> np.ndarray: """Vectorized version of point_relative_to_line for arrays of points. - - Args: - coords_x (np.ndarray): Array of x coordinates - coords_y (np.ndarray): Array of y coordinates - line_start (array-like): [x, y] coordinates of line start point - line_end (array-like): [x, y] coordinates of line end point - - Returns: - np.ndarray: Boolean array where True means point is to the left of the line + + Parameters + ---------- + coords_x : np.ndarray + Array of x coordinates. + coords_y : np.ndarray + Array of y coordinates. + line_start : array-like + [x, y] coordinates of line start point. + line_end : array-like + [x, y] coordinates of line end point. + + Returns + ------- + np.ndarray + Boolean array where True means point is to the left of the line. """ # Vector from line_start to line_end line_vec = np.array(line_end) - np.array(line_start) @@ -451,18 +540,29 @@ def vectorized_line_test(coords_x, coords_y, line_start, line_end): -def get_unique_contour_points(split_contours): +def get_unique_contour_points(split_contours: list[tuple[np.ndarray, np.ndarray]]) -> list[np.ndarray]: """Get unique contour points from the split contours. - This is a workaround to retrospectively add voxel-based sub-division - in the future we could keep track of the sub-division lines for - every sub-division scheme. - - Args: - split_contours (list): List of split contours (subsegmentations) - Returns: - list: List of unique contour points - + Parameters + ---------- + split_contours : list[tuple[np.ndarray, np.ndarray]] + List of split contours (subsegmentations), each containing x and y coordinates. + + Returns + ------- + list[np.ndarray] + List of unique contour points for each subsegment. 
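`vectorized_line_test` boils down to the sign of a 2D cross product between the line direction and the vector from the line start to each point. A worked toy case under the same left-of-line convention:

```python
import numpy as np

line_start, line_end = np.array([0.0, 0.0]), np.array([1.0, 0.0])
coords_x = np.array([0.5, 0.5])
coords_y = np.array([1.0, -1.0])

line_vec = line_end - line_start
cross = (line_vec[0] * (coords_y - line_start[1])
         - line_vec[1] * (coords_x - line_start[0]))
print(cross > 0)  # [ True False]: only the first point lies left of the line
```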
+ + Notes + ----- + This is a workaround to retrospectively add voxel-based subdivision. + In the future, we could keep track of the subdivision lines for + every subdivision scheme. + + The function: + 1. Processes each contour point. + 2. Checks if it appears in other contours (with small tolerance). + 3. Collects points unique to each subsegment. """ # For each contour point, check if it appears in other contours unique_contour_points = [] @@ -496,24 +596,37 @@ def get_unique_contour_points(split_contours): return unique_contour_points -def make_subdivision_mask(slice_shape, split_contours): +def make_subdivision_mask( + slice_shape: tuple[int, int], + split_contours: list[tuple[np.ndarray, np.ndarray]], + vox_size: tuple[float, float, float] +) -> np.ndarray: """Create a mask for subdividing the corpus callosum based on split contours. - This function creates a mask that assigns different labels to different segments of the corpus callosum - based on the subdivision lines defined by the split contours. Each segment is labeled with a value from - SUBSEGEMNT_LABELS. - - Args: - slice_shape (tuple): - Shape of the slice (rows, cols) - split_contours (list): - List of contours defining the subdivisions. - Each contour is a tuple of x and y coordinates. - - Returns: - ndarray: - A mask of shape slice_shape where each pixel is labeled with a value from SUBSEGEMNT_LABELS - indicating which subdivision segment it belongs to. + Parameters + ---------- + slice_shape : tuple[int, int] + Shape of the slice (rows, cols). + split_contours : list[tuple[np.ndarray, np.ndarray]] + List of contours defining the subdivisions. + Each contour is a tuple of x and y coordinates. + + Returns + ------- + np.ndarray + A mask of shape slice_shape where each pixel is labeled with a value + from SUBSEGEMNT_LABELS indicating which subdivision segment it belongs to. + + Notes + ----- + The function: + 1. Extracts unique contour points at subdivision boundaries. + 2. Creates coordinate grids for all points in the slice. + 3. Initializes mask with first segment label. + 4. For each subdivision line: + - Tests which points lie to the right of the line. + - Updates labels for those points. 
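`make_subdivision_mask` combines coordinate grids with that side test to paint subsegment labels onto the slice. A toy 4x4 version with two hypothetical labels and a single vertical subdivision line:

```python
import numpy as np

labels = [251, 252]  # hypothetical subsegment labels
rows, cols = 4, 4
y_coords, x_coords = np.mgrid[0:rows, 0:cols]

mask = np.full((rows, cols), labels[0], dtype=np.int32)
line_start, line_end = np.array([2.0, 0.0]), np.array([2.0, 3.0])  # (x, y) points

line_vec = line_end - line_start
cross = (line_vec[0] * (y_coords - line_start[1])
         - line_vec[1] * (x_coords - line_start[0]))
mask[cross > 0] = labels[1]  # here cross > 0 means x < 2, left of the upward line
```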
+ """ # unique contour points are the points where sub-division lines were inserted @@ -527,20 +640,23 @@ def make_subdivision_mask(slice_shape, split_contours): # Create coordinate grids for all points in the slice rows, cols = slice_shape y_coords, x_coords = np.mgrid[0:rows, 0:cols] + + subsegment_labels_anterior_posterior = SUBSEGEMNT_LABELS.copy() + subsegment_labels_anterior_posterior.reverse() # Initialize with first segment label - subdivision_mask = np.full(slice_shape, SUBSEGEMNT_LABELS[0], dtype=np.int32) + subdivision_mask = np.full(slice_shape, subsegment_labels_anterior_posterior[0], dtype=np.int32) # Process each subdivision line for segment_idx, segment_points in enumerate(subdivision_segments): - line_start = segment_points[0] - line_end = segment_points[-1] + line_start = segment_points[0] / vox_size[0] + line_end = segment_points[-1] / vox_size[0] # Vectorized test: find all points to the right of this line points_right_of_line = vectorized_line_test(x_coords, y_coords, line_start, line_end) # All points to the right of this line belong to the next segment or beyond - subdivision_mask[points_right_of_line] = SUBSEGEMNT_LABELS[segment_idx + 1] + subdivision_mask[points_right_of_line] = subsegment_labels_anterior_posterior[segment_idx + 1] # Debug visualization (optional) # import matplotlib.pyplot as plt @@ -553,15 +669,22 @@ def make_subdivision_mask(slice_shape, split_contours): return subdivision_mask -def check_area_changes(contours: list[np.ndarray], threshold: float = 0.3, verbose: bool = False) -> None: +def check_area_changes(contours: list[np.ndarray], threshold: float = 0.3, verbose: bool = False) -> bool: """Check for large changes between consecutive CC areas and issue warnings. - This function checks if any two consecutive areas have a change greater than - the specified threshold (default 30%) and issues a warning if they do. - - Args: - contours (list[np.ndarray]): List of contours - threshold (float, optional): Threshold for relative change. Defaults to 0.3 (30%). + Parameters + ---------- + contours : list[np.ndarray] + List of contours. + threshold : float, optional + Threshold for relative change, by default 0.3 (30%). + verbose : bool, optional + Whether to print warnings, by default False. + + Returns + ------- + bool + True if no large area changes are detected, False otherwise. """ areas = [np.sum(np.sqrt(np.sum((np.diff(contour, axis=0))**2, axis=1))) for contour in contours] diff --git a/CorpusCallosum/shape/cc_subsegment_contour.py b/CorpusCallosum/shape/cc_subsegment_contour.py index 25bb89cc..ced67946 100644 --- a/CorpusCallosum/shape/cc_subsegment_contour.py +++ b/CorpusCallosum/shape/cc_subsegment_contour.py @@ -12,17 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable + +import matplotlib.pyplot as plt import numpy as np from scipy.spatial import ConvexHull def minimum_bounding_rectangle(points): - """ - Find the smallest bounding rectangle for a set of points. - Returns a set of points representing the corners of the bounding box. + """Find the smallest bounding rectangle for a set of points. + + Parameters + ---------- + points : np.ndarray + Array of shape (N, 2) containing point coordinates. - :param points: an nx2 matrix of coordinates - :rval: an nx2 matrix of coordinates + Returns + ------- + np.ndarray + Array of shape (4, 2) containing coordinates of the bounding box corners. 
""" pi2 = np.pi / 2.0 points = points.T @@ -73,6 +81,18 @@ def minimum_bounding_rectangle(points): def get_area_from_subsegments(split_contours): + """Calculate area of each subsegment using the shoelace formula. + + Parameters + ---------- + split_contours : list[np.ndarray] + List of contour arrays, each of shape (2, N). + + Returns + ------- + np.ndarray + Array containing the area of each subsegment. + """ # calculate area of each split contour using the shoelace formula areas = [np.abs(np.trapz(split_contour[1], split_contour[0])) for split_contour in split_contours] area_out = np.zeros(len(areas)) @@ -85,6 +105,33 @@ def get_area_from_subsegments(split_contours): def subsegment_midline_orthogonal(midline, area_weights, contour, plot=True, ax=None, extremes=None): + """Subsegment contour orthogonally to the midline based on area weights. + + Parameters + ---------- + midline : np.ndarray + Array of shape (N, 2) containing midline points. + area_weights : np.ndarray + Array of weights for area-based subdivision. + contour : np.ndarray + Array of shape (2, M) containing contour points. + plot : bool, optional + Whether to plot the results, by default True. + ax : matplotlib.axes.Axes, optional + Axes for plotting, by default None. + extremes : tuple, optional + Tuple of extreme points, by default None. + + Returns + ------- + split_contours : list[np.ndarray] + List of contour arrays for each subsegment. + split_points : np.ndarray + Array of split points. + edge_directions : np.ndarray + Array of edge directions at split points. + + """ # get points after midline length of splits # get vertex closest to midline end @@ -415,8 +462,53 @@ def hampel_subdivide_contour(contour, num_rays, plot=False, ax=None): def subdivide_contour( - contour, area_weights, plot=False, ax=None, plot_transform=None, oriented=False, hline_anchor=None + contour: np.ndarray, + area_weights: list[float], + plot: bool = False, + ax: plt.Axes | None = None, + plot_transform: Callable | None = None, + oriented: bool = False, + hline_anchor: np.ndarray | None = None ): + """Subdivide contour based on area weights using vertical lines. + + Divides the contour into segments by drawing vertical lines at positions + determined by the area weights. The lines are drawn perpendicular to a + reference line connecting the extreme points of the contour. + + Parameters + ---------- + contour : np.ndarray + Array of shape (2, N) containing contour points. + area_weights : list[float] + List of weights for area-based subdivision. + plot : bool, optional + Whether to plot the results, by default False. + ax : matplotlib.axes.Axes, optional + Axes for plotting, by default None. + plot_transform : callable, optional + Function to transform points before plotting, by default None. + oriented : bool, optional + If True, use fixed horizontal reference line, by default False. + hline_anchor : np.ndarray, optional + Point to anchor horizontal reference line, by default None. + + Returns + ------- + areas : np.ndarray + Array of areas for each subsegment. + split_contours : list[np.ndarray] + List of contour arrays for each subsegment. + + Notes + ----- + The subdivision process: + 1. Finds extreme points in x-direction. + 2. Creates reference line between extremes. + 3. Calculates split points based on area weights. + 4. Divides contour using perpendicular lines at split points. 
+ + """ # Find the extreme points in the x-direction min_x_index = np.argmax(contour[0]) contour = np.roll(contour, -min_x_index, axis=1) @@ -712,6 +804,34 @@ def subdivide_contour( def transform_to_acpc_standard(contour_ras, ac_pt_ras, pc_pt_ras): + """Transform contour coordinates to AC-PC standard space. + + Transforms the contour coordinates by: + 1. Translating AC point to origin. + 2. Rotating to align PC point with posterior direction. + 3. Scaling to maintain AC-PC distance. + + Parameters + ---------- + contour_ras : np.ndarray + Array of shape (2, N) or (3, N) containing contour points in RAS space. + ac_pt_ras : np.ndarray + Anterior commissure point coordinates in RAS space. + pc_pt_ras : np.ndarray + Posterior commissure point coordinates in RAS space. + + Returns + ------- + contour_acpc : np.ndarray + Transformed contour points in AC-PC space. + ac_pt_acpc : np.ndarray + AC point in AC-PC space (origin). + pc_pt_acpc : np.ndarray + PC point in AC-PC space. + rotate_back : callable + Function to transform points back to RAS space. + + """ # translate AC to the origin and PC to (0, ac_pc_dist) translation_matrix = np.array([[1, 0, -ac_pt_ras[0]], [0, 1, -ac_pt_ras[1]], [0, 0, 1]]) @@ -749,6 +869,27 @@ def rotate_back(x): def preprocess_cc(cc_label_nib, paths_csv, subj_id): + """Preprocess corpus callosum mask and extract AC/PC coordinates. + + Parameters + ---------- + cc_label_nib : nibabel.Nifti1Image + NIfTI image containing corpus callosum segmentation. + paths_csv : pd.DataFrame + DataFrame containing AC and PC coordinates. + subj_id : str + Subject ID to look up in paths_csv. + + Returns + ------- + cc_mask : np.ndarray + Binary mask of corpus callosum. + AC_2d : np.ndarray + 2D coordinates of anterior commissure. + PC_2d : np.ndarray + 2D coordinates of posterior commissure. + + """ cc_mask = cc_label_nib.get_fdata() == 192 cc_mask = cc_mask[cc_mask.shape[0] // 2] @@ -768,6 +909,27 @@ def preprocess_cc(cc_label_nib, paths_csv, subj_id): def get_primary_eigenvector(contour_ras): + """Calculate primary eigenvector of contour points using PCA. + + Computes the principal direction of the contour by: + 1. Centering the points + 2. Computing covariance matrix + 3. Finding eigenvectors + 4. Selecting primary direction + + Parameters + ---------- + contour_ras : np.ndarray + Array of shape (2, N) containing contour points in RAS space. + + Returns + ------- + pt0 : np.ndarray + Starting point for eigenvector line. + pt1 : np.ndarray + End point for eigenvector line. + + """ # Center the data by subtracting mean contour_mean = np.mean(contour_ras, axis=1, keepdims=True) contour_centered = contour_ras - contour_mean diff --git a/CorpusCallosum/shape/cc_thickness.py b/CorpusCallosum/shape/cc_thickness.py index 478eae16..788101b2 100644 --- a/CorpusCallosum/shape/cc_thickness.py +++ b/CorpusCallosum/shape/cc_thickness.py @@ -21,7 +21,19 @@ from CorpusCallosum.utils.utils import HiddenPrints -def compute_curvature(path): +def compute_curvature(path: np.ndarray) -> np.ndarray: + """Compute curvature by computing edge angles. + + Parameters + ---------- + path : np.ndarray + Array of shape (N, 2) containing path coordinates. + + Returns + ------- + np.ndarray + Array of angle differences between consecutive edges. 
+ """ # compute curvature by computing edge angles edges = np.diff(path, axis=0) angles = np.arctan2(edges[:, 1], edges[:, 0]) @@ -32,7 +44,33 @@ def compute_curvature(path): return angle_diffs -def convert_to_ras(contour, vox2ras_matrix, get_parameters=False): +def convert_to_ras( + contour: np.ndarray, + vox2ras_matrix: np.ndarray, + get_parameters: bool = False +) -> np.ndarray | tuple[np.ndarray, bool, bool, bool]: + """Convert contour coordinates from voxel space to RAS space. + + Parameters + ---------- + contour : np.ndarray + Array of shape (2, N) or (3, N) containing contour coordinates. + vox2ras_matrix : np.ndarray + 4x4 voxel to RAS transformation matrix. + get_parameters : bool, optional + If True, return additional transformation parameters, by default False. + + Returns + ------- + np.ndarray | tuple[np.ndarray, bool, bool, bool] + If get_parameters is False: + Transformed contour coordinates. + If get_parameters is True: + - anterior_reversed : bool - Whether anterior axis was reversed. + - superior_reversed : bool - Whether superior axis was reversed. + - swap_axes : bool - Whether axes were swapped. + + """ # converting to AS (no left-right dimension), out of plane movement is ignores, # so we only do scaling, axes swapping and flipping - no rotation # translation is ignored @@ -81,6 +119,29 @@ def convert_to_ras(contour, vox2ras_matrix, get_parameters=False): def set_contour_zero_idx(contour, idx, anterior_endpoint_idx, posterior_endpoint_idx): + """Roll contour points to set a new zero index, while keeping track of CC endpoints. + + Parameters + ---------- + contour : np.ndarray + Array of contour points. + idx : int + New zero index. + anterior_endpoint_idx : int + Index of anterior endpoint. + posterior_endpoint_idx : int + Index of posterior endpoint. + + Returns + ------- + tuple + - contour : np.ndarray + Rolled contour points. + - anterior_endpoint_idx : int + Updated anterior endpoint index. + - posterior_endpoint_idx : int + Updated posterior endpoint index. + """ contour = np.roll(contour, -idx, axis=0) anterior_endpoint_idx = (anterior_endpoint_idx - idx) % contour.shape[0] posterior_endpoint_idx = (posterior_endpoint_idx - idx) % contour.shape[0] @@ -90,12 +151,17 @@ def set_contour_zero_idx(contour, idx, anterior_endpoint_idx, posterior_endpoint def find_closest_edge(point, contour): """Find the index of the edge closest to the given point. - Args: - point: 2D point coordinates - contour: Array of contour points (N x 2) - - Returns: - Index of the closest edge + Parameters + ---------- + point : np.ndarray + 2D point coordinates. + contour : np.ndarray + Array of shape (N, 2) containing contour points. + + Returns + ------- + int + Index of the closest edge. """ edges_start = contour[:-1, :2] # N-1 x 2 edges_end = contour[1:, :2] # N-1 x 2 @@ -123,18 +189,33 @@ def find_closest_edge(point, contour): return np.argmin(distances) -def insert_point_to_contour( - contour_with_thickness, point, thickness_value, get_index=False -): +def insert_point_with_thickness( + contour_with_thickness: list[np.ndarray], + point: np.ndarray, + thickness_value: float, + get_index: bool = False +) -> tuple[list[np.ndarray], int] | list[np.ndarray]: """Insert a point and its thickness value into the contour. 
- Args: - contour_with_thickness: List containing [contour_points, thickness_values] - point: 2D point to insert - thickness_value: Thickness value corresponding to the point - - Returns: - Updated contour_with_thickness + Parameters + ---------- + contour_with_thickness : list[np.ndarray] + List containing [contour_points, thickness_values]. + point : np.ndarray + 2D point to insert, shape (2,). + thickness_value : float + Thickness value corresponding to the point. + get_index : bool, optional + If True, return the index where point was inserted, by default False. + + Returns + ------- + tuple[list[np.ndarray], int] or list[np.ndarray] + If get_index is True: + - Updated contour_with_thickness. + - Index where point was inserted. + If get_index is False: + - Updated contour_with_thickness. """ # Find closest edge for the point edge_idx = find_closest_edge(point, contour_with_thickness[0]) @@ -153,7 +234,36 @@ def insert_point_to_contour( return contour_with_thickness -def make_mesh_from_contour(contour_2d, max_volume=0.5, min_angle=25, verbose=False): +def make_mesh_from_contour( + contour_2d: np.ndarray, + max_volume: float = 0.5, + min_angle: float = 25, + verbose: bool = False +) -> tuple[np.ndarray, np.ndarray]: + """Create a triangular mesh from a 2D contour. + + Parameters + ---------- + contour_2d : np.ndarray + Array of shape (N, 2) containing contour points. + max_volume : float, optional + Maximum triangle area, by default 0.5. + min_angle : float, optional + Minimum angle in triangles (degrees), by default 25. + verbose : bool, optional + Whether to print mesh generation info, by default False. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + - mesh_points : Array of shape (M, 2) containing mesh vertices. + - mesh_trias : Array of shape (K, 3) containing triangle indices. + + Notes + ----- + Uses meshpy.triangle to create a constrained Delaunay triangulation + of the contour. The contour must not have duplicate points. + """ facets = np.vstack( ( @@ -186,8 +296,38 @@ def make_mesh_from_contour(contour_2d, max_volume=0.5, min_angle=25, verbose=Fal def cc_thickness( - contour_2d, anterior_endpoint_idx, posterior_endpoint_idx, n_points=100 -): + contour_2d: np.ndarray, + anterior_endpoint_idx: int, + posterior_endpoint_idx: int, + n_points: int = 100 +) -> tuple[np.ndarray, np.ndarray]: + """Calculate corpus callosum thickness using Laplace equation. + + Parameters + ---------- + contour_2d : np.ndarray + Array of shape (N, 2) containing contour points. + anterior_endpoint_idx : int + Index of anterior endpoint in contour. + posterior_endpoint_idx : int + Index of posterior endpoint in contour. + n_points : int, optional + Number of points for thickness measurement, by default 100. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + - thickness_values : Array of thickness measurements. + - measurement_points : Array of points where thickness was measured. + + Notes + ----- + Uses the Laplace equation to compute thickness by: + 1. Creating a triangular mesh from the contour + 2. Setting boundary conditions (0 at endpoints, ±1 on sides) + 3. Solving Laplace equation to get level sets + 4. 
Computing thickness along level sets + """ # standardize contour indices, to get consistent levelpath directions contour_2d, anterior_endpoint_idx, posterior_endpoint_idx = set_contour_zero_idx( @@ -287,10 +427,10 @@ def cc_thickness( levelpath_start = lvlpath[0, :2] levelpath_end = lvlpath[-1, :2] - contour_with_thickness, inserted_idx_start = insert_point_to_contour( + contour_with_thickness, inserted_idx_start = insert_point_with_thickness( contour_with_thickness, levelpath_start, lvlpath_length, get_index=True ) - contour_with_thickness, inserted_idx_end = insert_point_to_contour( + contour_with_thickness, inserted_idx_end = insert_point_with_thickness( contour_with_thickness, levelpath_end, lvlpath_length, get_index=True ) diff --git a/CorpusCallosum/transforms/localization_transforms.py b/CorpusCallosum/transforms/localization_transforms.py index 2f106beb..a02b7e89 100644 --- a/CorpusCallosum/transforms/localization_transforms.py +++ b/CorpusCallosum/transforms/localization_transforms.py @@ -17,19 +17,68 @@ class CropAroundACPCFixedSize(RandomizableTransform, MapTransform): - """ - Crop around AC and PC with fixed size + """Crop image around AC-PC points with fixed size. + + A transform that crops the input image around the midpoint between + AC and PC points with a fixed size window and optional random translation. + + Parameters + ---------- + keys : list[str] + Keys of the data dictionary to apply the transform to + fixed_size : tuple[int, int] + Fixed size of the crop window (width, height) + allow_missing_keys : bool, optional + Whether to allow missing keys in the data dictionary, by default False + random_translate : float, optional + Maximum random translation in voxels, by default 0 + + Notes + ----- + The transform expects the following keys in the data dictionary: + - AC_center : np.ndarray + Coordinates of anterior commissure + - PC_center : np.ndarray + Coordinates of posterior commissure + - image : np.ndarray + Input image to crop + + Raises + ------ + ValueError + If the crop boundaries extend outside the image dimensions """ - def __init__(self, keys, fixed_size: tuple[int, int], allow_missing_keys: bool = False, - random_translate: float = 0) -> None: + def __init__(self, keys: list[str], fixed_size: tuple[int, int], + allow_missing_keys: bool = False, random_translate: float = 0) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self) self.random_translate = random_translate self.fixed_size = fixed_size - - def __call__(self, data): + def __call__(self, data: dict) -> dict: + """Apply the transform to the data. + + Parameters + ---------- + data : dict + Dictionary containing the data to transform + + Returns + ------- + dict + Transformed data dictionary with cropped images and updated coordinates. + Also includes crop boundary information: + - crop_left : int + - crop_right : int + - crop_top : int + - crop_bottom : int + + Raises + ------ + ValueError + If crop boundaries extend outside the image dimensions + """ d = dict(data) for key in self.keys: diff --git a/CorpusCallosum/transforms/segmentation_transforms.py b/CorpusCallosum/transforms/segmentation_transforms.py index 565b1dfe..051bc5b4 100644 --- a/CorpusCallosum/transforms/segmentation_transforms.py +++ b/CorpusCallosum/transforms/segmentation_transforms.py @@ -17,20 +17,54 @@ class CropAroundACPC(RandomizableTransform, MapTransform): - """ - Crop around AC and PC + """Crop image around anterior and posterior commissure points. 
+ + A transform that crops the input image around the AC and PC points with + optional padding and random translation. + + Parameters + ---------- + keys : list[str] + Keys of the data dictionary to apply the transform to + allow_missing_keys : bool, optional + Whether to allow missing keys in the data dictionary, by default False + padding_mm : float, optional + Padding around AC-PC region in millimeters, by default 10 + random_translate : float, optional + Maximum random translation in voxels, by default 0 + + Notes + ----- + The transform expects the following keys in the data dictionary: + - AC_center : np.ndarray + Coordinates of anterior commissure + - PC_center : np.ndarray + Coordinates of posterior commissure + - res : float + Voxel resolution in mm """ - def __init__(self, keys, allow_missing_keys: bool = False, padding_mm: float = 10, - random_translate: float = 0) -> None: + def __init__(self, keys: list[str], allow_missing_keys: bool = False, + padding_mm: float = 10, random_translate: float = 0) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob=1, do_transform=True) self.padding_mm = padding_mm self.random_translate = random_translate - def __call__(self, data): - d = dict(data) + def __call__(self, data: dict) -> dict: + """Apply the transform to the data. + Parameters + ---------- + data : dict + Dictionary containing the data to transform + + Returns + ------- + dict + Transformed data dictionary + """ + d = dict(data) if 'AC_center_original' not in d: d['AC_center_original'] = d['AC_center'].copy() @@ -49,13 +83,10 @@ def __call__(self, data): pc_center = d['PC_center'] ac_center = d['AC_center'] - # 'PC_center': array([ 2., 139., 143.], dtype=float32), 'AC_center': array([ 2., 128., 168.] - ac_pc_bottomleft = (np.min([ac_center[1], pc_center[1]]).astype(int), - np.min([ac_center[2], pc_center[2]]).astype(int)) + np.min([ac_center[2], pc_center[2]]).astype(int)) ac_pc_topright = (np.max([ac_center[1], pc_center[1]]).astype(int), - np.max([ac_center[2], pc_center[2]]).astype(int)) - + np.max([ac_center[2], pc_center[2]]).astype(int)) voxel_padding = round(self.padding_mm / d['res']) @@ -64,25 +95,55 @@ def __call__(self, data): crop_top = ac_pc_bottomleft[1]-voxel_padding+random_translate[1] crop_bottom = ac_pc_topright[1]+voxel_padding+random_translate[1] - d['to_pad'] = crop_left, d[key].shape[2]-crop_right, crop_top, d[key].shape[3]-crop_bottom d[key] = d[key][:, :, crop_left:crop_right, crop_top:crop_bottom] - - #d[key] = d[key][:, d[key].shape[1]//2-voxel_padding:d[key].shape[2]//2+voxel_padding] - - #print('cropped', d[key].shape, 'for key', key) - return d + class CropAroundACPCtrack(CropAroundACPC): + """Crop image around AC-PC points and update their coordinates. + + Extends CropAroundACPC to also adjust the AC and PC center coordinates + after cropping to maintain their correct positions in the cropped image. 
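The crop window in `CropAroundACPC` is the AC/PC bounding box grown by `padding_mm` converted to voxels via the `res` key. A toy 2D version of that arithmetic, with made-up coordinates and resolution:

```python
import numpy as np

ac_center = np.array([30, 40])   # (row, col) voxel coordinates, illustrative
pc_center = np.array([50, 60])
res, padding_mm = 1.0, 10.0      # mm per voxel, padding in mm

voxel_padding = round(padding_mm / res)
top_left = np.minimum(ac_center, pc_center) - voxel_padding
bottom_right = np.maximum(ac_center, pc_center) + voxel_padding

image = np.zeros((128, 128))
crop = image[top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]]
print(crop.shape)  # (40, 40)
```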
+ + Parameters + ---------- + keys : list[str] + Keys of the data dictionary to apply the transform to + allow_missing_keys : bool, optional + Whether to allow missing keys in the data dictionary, by default False + padding_mm : float, optional + Padding around AC-PC region in millimeters, by default 10 + random_translate : float, optional + Maximum random translation in voxels, by default 0 + + Notes + ----- + The transform expects the following keys in the data dictionary: + - AC_center : np.ndarray + Coordinates of anterior commissure + - PC_center : np.ndarray + Coordinates of posterior commissure + - AC_center_original : np.ndarray + Original coordinates of anterior commissure + - PC_center_original : np.ndarray + Original coordinates of posterior commissure """ - Same as crop around ACPC but also adjusts AC_center and PC_center accordingly - - - """ - def __call__(self, data): + def __call__(self, data: dict) -> dict: + """Apply the transform to the data. + + Parameters + ---------- + data : dict + Dictionary containing the data to transform + + Returns + ------- + dict + Transformed data dictionary with updated AC and PC coordinates + """ # First call parent class to get cropped image diff --git a/CorpusCallosum/utils/utils.py b/CorpusCallosum/utils/utils.py index 4b79e8e6..15aa0714 100644 --- a/CorpusCallosum/utils/utils.py +++ b/CorpusCallosum/utils/utils.py @@ -17,10 +17,44 @@ class HiddenPrints: - def __enter__(self): + """Context manager for suppressing stdout output. + + Temporarily redirects stdout to os.devnull to hide any print statements + within the context. + + Examples + -------- + >>> with HiddenPrints(): + ... print("This will not be visible") + >>> print("This will be visible") + """ + + def __enter__(self) -> None: + """Enter the context manager. + + Returns + ------- + None + """ self._original_stdout = sys.stdout sys.stdout = open(os.devnull, "w") - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: type | None, exc_val: Exception | None, + exc_tb: type | None) -> None: + """Exit the context manager. + + Parameters + ---------- + exc_type : type or None + Type of the exception that occurred, if any + exc_val : Exception or None + Exception instance that occurred, if any + exc_tb : type or None + Traceback of the exception that occurred, if any + + Returns + ------- + None + """ sys.stdout.close() sys.stdout = self._original_stdout \ No newline at end of file diff --git a/CorpusCallosum/visualization/visualization.py b/CorpusCallosum/visualization/visualization.py index 3574045b..a0f93e12 100644 --- a/CorpusCallosum/visualization/visualization.py +++ b/CorpusCallosum/visualization/visualization.py @@ -15,17 +15,35 @@ from pathlib import Path import matplotlib.pyplot as plt +import nibabel as nib import numpy as np -def plot_standardized_space(ax_row, vol, ac_coords, pc_coords): +def plot_standardized_space( + ax_row: list[plt.Axes], + vol: np.ndarray, + ac_coords: np.ndarray, + pc_coords: np.ndarray +) -> None: """Plot standardized space visualization across three views. 
- Args: - ax_row: Row of axes to plot on (should be length 3) - vol: Volume data to visualize - ac_coords: AC coordinates in standardized space - pc_coords: PC coordinates in standardized space + Parameters + ---------- + ax_row : list[plt.Axes] + Row of axes to plot on (should be length 3) + vol : np.ndarray + Volume data to visualize + ac_coords : np.ndarray + AC coordinates in standardized space + pc_coords : np.ndarray + PC coordinates in standardized space + + Notes + ----- + Creates three views: + - Axial (top view) + - Sagittal (side view) + - Coronal (front view) """ ax_row[0].set_title("Standardized") @@ -46,28 +64,48 @@ def plot_standardized_space(ax_row, vol, ac_coords, pc_coords): def visualize_coordinate_spaces( - orig, - upright, - standardized, - ac_coords_orig, - pc_coords_orig, - ac_coords_3d, - pc_coords_3d, - ac_coords_standardized, - pc_coords_standardized, - output_dir, -): - """ - Visualize the AC and PC coordinates in different coordinate spaces for testing/debugging. - - Args: - orig: Original image volume - vol: Volume in fsaverage space - vol2: Volume after nodding correction - vol3: Volume after translation - ac_coords_*: AC coordinates in different spaces - pc_coords_*: PC coordinates in different spaces - output_dir: Directory to save visualization + orig: "nib.Nifti1Image", + upright: np.ndarray, + standardized: np.ndarray, + ac_coords_orig: np.ndarray, + pc_coords_orig: np.ndarray, + ac_coords_3d: np.ndarray, + pc_coords_3d: np.ndarray, + ac_coords_standardized: np.ndarray, + pc_coords_standardized: np.ndarray, + output_dir: str | Path, +) -> None: + """Visualize the AC and PC coordinates in different coordinate spaces. + + Creates a figure showing the anterior and posterior commissure points + in three different coordinate spaces for testing/debugging. + + Parameters + ---------- + orig : nibabel.Nifti1Image + Original image volume + upright : np.ndarray + Volume in fsaverage space + standardized : np.ndarray + Volume in standardized space + ac_coords_orig : np.ndarray + AC coordinates in original space + pc_coords_orig : np.ndarray + PC coordinates in original space + ac_coords_3d : np.ndarray + AC coordinates in fsaverage space + pc_coords_3d : np.ndarray + PC coordinates in fsaverage space + ac_coords_standardized : np.ndarray + AC coordinates in standardized space + pc_coords_standardized : np.ndarray + PC coordinates in standardized space + output_dir : str or Path + Directory to save visualization + + Notes + ----- + Saves the visualization as 'ac_pc_spaces.png' in the output directory. """ fig, ax = plt.subplots(3, 4) ax = ax.T @@ -95,32 +133,48 @@ def visualize_coordinate_spaces( def plot_contours( transformed: np.ndarray, - split_contours: list[np.ndarray], - split_contours_hofer_frahm: list[np.ndarray], - midline_equidistant: np.ndarray, - levelpaths: list[np.ndarray], - output_path: str, - ac_coords: np.ndarray, - pc_coords: np.ndarray, - vox_size: float, - title: str = None, + split_contours: list[np.ndarray] | None = None, + split_contours_hofer_frahm: list[np.ndarray] | None = None, + midline_equidistant: np.ndarray | None = None, + levelpaths: list[np.ndarray] | None = None, + output_path: str | Path | None = None, + ac_coords: np.ndarray | None = None, + pc_coords: np.ndarray | None = None, + vox_size: float | None = None, + title: str = "", + debug: bool = False, ) -> None: - """Plots corpus callosum contours and segmentations. - - Creates a figure with three subplots showing: - 1. Midline-based subsegmentation - 2. 
Hofer-Frahm segmentation scheme - 3. Midline and levelpaths visualization - - Args: - transformed: The transformed brain image array - split_contours: List of contour arrays for midline-based segmentation - split_contours_hofer_frahm: List of contour arrays for Hofer-Frahm segmentation - midline_equidistant: Array of midline points - levelpaths: List of levelpath arrays - output_dir: Directory to save the output plot - ac_coords: Anterior commissure coordinates - pc_coords: Posterior commissure coordinates + """Plot contours and subdivisions of the corpus callosum. + + Parameters + ---------- + transformed : np.ndarray + Transformed image data + split_contours : list[np.ndarray], optional + List of contour arrays for each subdivision, by default None + split_contours_hofer_frahm : list[np.ndarray], optional + List of contour arrays using Hofer-Frahm subdivision, by default None + midline_equidistant : np.ndarray, optional + Midline points at equidistant spacing, by default None + levelpaths : list[np.ndarray], optional + List of level paths for visualization, by default None + output_path : str or Path, optional + Path to save the plot, by default None + ac_coords : np.ndarray, optional + AC coordinates for visualization, by default None + pc_coords : np.ndarray, optional + PC coordinates for visualization, by default None + vox_size : float, optional + Voxel size for scaling, by default None + title : str, optional + Title for the plot, by default "" + debug : bool, optional + Whether to show debug information, by default False + + Notes + ----- + Creates a visualization of the corpus callosum contours and their subdivisions. + If output_path is provided, saves the plot to that location. """ # scale contour data by vox_size @@ -206,13 +260,23 @@ def plot_contours( # plt.show() -def plot_midplane(grid_orig, orig): - """ - Creates a 3D visualization of grid points in original image space. +def plot_midplane(grid_orig: np.ndarray, orig: np.ndarray) -> None: + """Create a 3D visualization of grid points in original image space. + + Parameters + ---------- + grid_orig : np.ndarray + Grid points in original space, shape (3, N) + orig : np.ndarray + Original image for dimension reference - Args: - grid_orig: Grid points in original space - orig: Original image for dimension reference + Notes + ----- + The function: + 1. Creates a 3D scatter plot of grid points + 2. Samples every 40th point to avoid overcrowding + 3. Sets axis limits based on original image dimensions + 4. Shows the plot interactively """ # Create a figure showing grid points in original space From 825a9a855da293b2adfb3883c95fc2c7736d6daf Mon Sep 17 00:00:00 2001 From: ClePol Date: Thu, 2 Oct 2025 15:23:31 +0200 Subject: [PATCH 21/68] cleaned up stats writing --- CorpusCallosum/README.md | 6 +- CorpusCallosum/fastsurfer_cc.py | 158 ++++++++---------- CorpusCallosum/shape/cc_postprocessing.py | 25 +-- CorpusCallosum/visualization/visualization.py | 76 +++------ 4 files changed, 98 insertions(+), 167 deletions(-) diff --git a/CorpusCallosum/README.md b/CorpusCallosum/README.md index 2863569c..3c47d282 100644 --- a/CorpusCallosum/README.md +++ b/CorpusCallosum/README.md @@ -160,7 +160,6 @@ This file contains measurements from the middle sagittal slice and includes: **Subdivisions** - `areas`: Areas of CC using an improved Hofer-Frahm sub-division method (mm²). This gives more consistent sub-segemnts while preserving the original ratios. 
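Once written, the stats JSON files described here are plain dictionaries, so reading them back needs nothing beyond the standard library. A sketch with a hypothetical subject directory and the metric keys listed in this README:

```python
import json
from pathlib import Path

stats_path = Path("subjects/sub-01/stats/callosum.CC.all_slices.json")  # hypothetical
stats = json.loads(stats_path.read_text())

print(stats["subdivision_ratios"])  # subdivision fractions used
for entry in stats["slices"]:
    print(entry["total_area"], entry["thickness"])
```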
-- `areas_hofer_frahm`: Areas using classical Hofer-Frahm subdivision method (mm²) **Thickness Analysis:** - `thickness`: Average corpus callosum thickness (mm) @@ -179,9 +178,6 @@ This file contains measurements from the middle sagittal slice and includes: - `ac_center_upright`: AC coordinates in upright space (cc_up.lta) - `pc_center_upright`: PC coordinates in upright space (cc_up.lta) -**Processing Parameters:** -- `num_slices`: Number of slices analyzed around the midplane - ### `stats/callosum.CC.all_slices.json` (Multi-Slice Analysis) This file contains comprehensive per-slice analysis when using `--slice_selection all`: @@ -191,7 +187,7 @@ This file contains comprehensive per-slice analysis when using `--slice_selectio - `voxel_size`: Voxel dimensions [x, y, z] in mm - `subdivision_method`: Method used for anatomical subdivision - `num_thickness_points`: Number of points used for thickness estimation -- `subdivisions`: Subdivision fractions used for regional analysis +- `subdivision_ratios`: Subdivision fractions used for regional analysis - `contour_smoothing`: Gaussian sigma used for contour smoothing - `slice_selection`: Slice selection mode used diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index e0354cb0..d5b02f26 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -737,62 +737,59 @@ def main( ) ) + # Save middle slice visualization + IO_processes.append( + create_visualization( + subdivision_method, + { + "split_contours": middle_slice_result["split_contours"], + "midline_equidistant": middle_slice_result["midline_equidistant"], + "levelpaths": middle_slice_result["levelpaths"], + }, + midslices, + output_dir, + ac_coords, + pc_coords, + orig.header.get_zooms()[0], + " (Middle Slice)", + ) + ) + + save_nifti_background(IO_processes, segmentation, seg_affine, orig.header, segmentation_path) + + + METRICS = [ + "areas", + "thickness", + "curvature", + "midline_length", + "circularity", + "cc_index", + "total_area", + "total_perimeter", + "thickness_profile", + ] + + # Record key metrics for middle slice + output_metrics_middle_slice = { + metric: middle_slice_result[metric] for metric in METRICS + } + # Create enhanced output dictionary with all slice results per_slice_output_dict = { "slices": [ convert_numpy_to_json_serializable( { - "slice_index": result["slice_index"], - "cc_index": result["cc_index"], - "circularity": result["circularity"], - "areas": result["areas"], - "midline_length": result["midline_length"], - "thickness": result["thickness"], - "curvature": result["curvature"], - "thickness_profile": result["thickness_profile"], - "total_area": result["total_area"], - "total_perimeter": result["total_perimeter"], + metric: result[metric] for metric in METRICS } ) for result in slice_results ], - "slices_in_segmentation": segmentation.shape[0], - "voxel_size": [float(x) for x in orig.header.get_zooms()], - "subdivision_method": subdivision_method, - "num_thickness_points": num_thickness_points, - "subdivisions": subdivisions, - "contour_smoothing": contour_smoothing, - "slice_selection": slice_selection, } - # Save slice-wise postprocessing results to JSON - with open(postproc_results_path, "w") as f: - json.dump(per_slice_output_dict, f, indent=4) - - if verbose: - logger.info(f"Multiple slice post-processing results saved to {postproc_results_path}") - ########## Save outputs ########## - - - - # Create backward compatible output_dict for existing pipeline using middle slice - output_dict = { - 
"areas": middle_slice_result["areas"], - "areas_hofer_frahm": middle_slice_result["areas"] - if middle_slice_result["split_contours_hofer_frahm"] is not None - else middle_slice_result["areas"], - "thickness": middle_slice_result["thickness"], - "curvature": middle_slice_result["curvature"], - "midline_length": middle_slice_result["midline_length"], - "circularity": middle_slice_result["circularity"], - "cc_index": middle_slice_result["cc_index"], - "total_area": middle_slice_result["total_area"], - "total_perimeter": middle_slice_result["total_perimeter"], - "thickness_profile": middle_slice_result["thickness_profile"], - } - + additional_metrics = {} if len(outer_contours) > 1: cc_volume_voxel = segmentation_postprocessing.get_cc_volume_voxel( desired_width_mm=5, @@ -807,42 +804,10 @@ def main( logger.info(f"CC volume voxel: {cc_volume_voxel}") logger.info(f"CC volume contour: {cc_volume_contour}") - output_dict["cc_5mm_volume"] = cc_volume_voxel - output_dict["cc_5mm_volume_pv_corrected"] = cc_volume_contour - - # multiply split contour with resolution scale factor for middle slice visualization - split_contours = [ - split_contour * orig.header.get_zooms()[1] for split_contour in middle_slice_result["split_contours"] - ] - if middle_slice_result["split_contours_hofer_frahm"] is not None: - split_contours_hofer_frahm = [ - split_contour * orig.header.get_zooms()[1] - for split_contour in middle_slice_result["split_contours_hofer_frahm"] - ] - else: - split_contours_hofer_frahm = split_contours # backward compatibility - midline_equidistant = middle_slice_result["midline_equidistant"] * orig.header.get_zooms()[1] - levelpaths = [levelpath * orig.header.get_zooms()[1] for levelpath in middle_slice_result["levelpaths"]] + additional_metrics["cc_5mm_volume"] = cc_volume_voxel + additional_metrics["cc_5mm_volume_pv_corrected"] = cc_volume_contour - # Save middle slice visualization - single_slice_result = { - "split_contours": split_contours, - "split_contours_hofer_frahm": split_contours_hofer_frahm, - "midline_equidistant": midline_equidistant, - "levelpaths": levelpaths, - } - IO_processes.append( - create_visualization( - subdivision_method, - single_slice_result, - midslices, - output_dir, - ac_coords, - pc_coords, - orig.header.get_zooms()[0], - " (Middle Slice)", - ) - ) + # get ac and pc in all spaces ac_coords_3d = np.hstack((FSAVERAGE_MIDDLE, ac_coords)) @@ -851,24 +816,37 @@ def main( get_mapping_to_standard_space(orig, ac_coords_3d, pc_coords_3d, orig_fsaverage_vox2vox) ) - - save_nifti_background(IO_processes, segmentation, seg_affine, orig.header, segmentation_path) - # write output dict as csv - output_dict["ac_center"] = ac_coords_orig - output_dict["pc_center"] = pc_coords_orig - output_dict["ac_center_oriented_volume"] = ac_coords_standardized - output_dict["pc_center_oriented_volume"] = pc_coords_standardized - output_dict["ac_center_upright"] = ac_coords_3d - output_dict["pc_center_upright"] = pc_coords_3d - output_dict["num_slices"] = slices_to_analyze + additional_metrics["ac_center"] = ac_coords_orig + additional_metrics["pc_center"] = pc_coords_orig + additional_metrics["ac_center_oriented_volume"] = ac_coords_standardized + additional_metrics["pc_center_oriented_volume"] = pc_coords_standardized + additional_metrics["ac_center_upright"] = ac_coords_3d + additional_metrics["pc_center_upright"] = pc_coords_3d + additional_metrics["slices_in_segmentation"] = slices_to_analyze + additional_metrics["voxel_size"] = [float(x) for x in orig.header.get_zooms()] + 
additional_metrics["num_thickness_points"] = num_thickness_points + additional_metrics["subdivision_method"] = subdivision_method + additional_metrics["subdivision_ratios"] = subdivisions + additional_metrics["contour_smoothing"] = contour_smoothing + additional_metrics["slice_selection"] = slice_selection # Convert numpy arrays to lists for JSON serialization - output_dict = convert_numpy_to_json_serializable(output_dict) + output_metrics_middle_slice = convert_numpy_to_json_serializable(output_metrics_middle_slice | additional_metrics) logger.info(f"Saving CC markers to {cc_markers_path}") with open(cc_markers_path, "w") as f: - json.dump(output_dict, f, indent=4) + json.dump(output_metrics_middle_slice, f, indent=4) + + + per_slice_output_dict = convert_numpy_to_json_serializable(per_slice_output_dict | additional_metrics) + + # Save slice-wise postprocessing results to JSON + with open(postproc_results_path, "w") as f: + json.dump(per_slice_output_dict, f, indent=4) + + if verbose: + logger.info(f"Multiple slice post-processing results saved to {postproc_results_path}") # save lta to fsaverage space logger.info(f"Saving LTA to fsaverage space: {upright_lta_path}") diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/cc_postprocessing.py index e542e700..a04caf54 100644 --- a/CorpusCallosum/shape/cc_postprocessing.py +++ b/CorpusCallosum/shape/cc_postprocessing.py @@ -54,7 +54,7 @@ def create_visualization(subdivision_method: str, result: dict, midslices_data: subdivision_method : str The subdivision method being used. result : dict - Dictionary containing processing results with split_contours and split_contours_hofer_frahm. + Dictionary containing processing results with split_contours. midslices_data : np.ndarray Slice data for visualization. output_image_path : str or Path @@ -76,10 +76,9 @@ def create_visualization(subdivision_method: str, result: dict, midslices_data: title = f'CC Subsegmentation by {subdivision_method} {title_suffix}' args_dict = { - 'debug': False, + 'debug': True, 'transformed': midslices_data, - 'split_contours': None, - 'split_contours_hofer_frahm': None, + 'split_contours': result['split_contours'], 'midline_equidistant': result['midline_equidistant'], 'levelpaths': result['levelpaths'], 'output_path': output_image_path, @@ -89,11 +88,6 @@ def create_visualization(subdivision_method: str, result: dict, midslices_data: 'title': title, } - if subdivision_method == "shape": - args_dict['split_contours'] = result['split_contours'] - else: - args_dict['split_contours_hofer_frahm'] = result['split_contours_hofer_frahm'] - return run_in_background(plot_contours, **args_dict) @@ -171,7 +165,6 @@ def process_slice( - total_area : float - Total area of the CC. - total_perimeter : float - Total perimeter length. - split_contours : list[np.ndarray] - Subdivided contour segments. - - split_contours_hofer_frahm : list[np.ndarray] - Alternative subdivision (if applicable). - midline_equidistant : np.ndarray - Equidistant points along midline. - levelpaths : list[np.ndarray] - Paths for thickness measurements. - thickness_measurement_points : np.ndarray - Points where thickness was measured. 
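A note on the JSON handling in the fastsurfer_cc.py hunk above: the metrics dictionaries are now merged with Python 3.9+'s `dict | dict` operator and routed through `convert_numpy_to_json_serializable` before `json.dump`, since numpy arrays and scalars are not JSON-serializable. A minimal sketch of that conversion pattern (illustrative only; the actual helper lives in `CorpusCallosum/data/read_write.py` and may differ in detail):

```python
import json

import numpy as np


def to_json_serializable(obj):
    """Recursively convert numpy arrays/scalars into plain Python types."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, (np.integer, np.floating)):
        return obj.item()
    if isinstance(obj, dict):
        return {key: to_json_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_json_serializable(value) for value in obj]
    return obj


# hypothetical example values, mirroring the middle-slice metrics merge above
metrics = {"thickness": np.float64(6.1), "areas": np.arange(3)}
additional_metrics = {"voxel_size": [1.0, 1.0, 1.0]}
print(json.dumps(to_json_serializable(metrics | additional_metrics), indent=4))
```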
@@ -230,17 +223,14 @@ def process_slice( contour_1mm[:,anterior_endpoint_idx], contour_1mm[:,posterior_endpoint_idx])[0] for split_contour in split_contours] - split_contours_hofer_frahm = None elif subdivision_method == "vertical": areas, split_contours = subdivide_contour(contour_acpc, subdivisions, plot=False) - split_contours_hofer_frahm = split_contours.copy() elif subdivision_method == "angular": if not np.allclose(np.diff(subdivisions), np.diff(subdivisions)[0]): logger.error('Error: Angular subdivision method (Hampel) only supports equidistant subdivision, ' f'but got: {subdivisions}') return None areas, split_contours = hampel_subdivide_contour(contour_acpc, num_rays=len(subdivisions), plot=False) - split_contours_hofer_frahm = split_contours.copy() elif subdivision_method == "eigenvector": pt0, pt1 = get_primary_eigenvector(contour_acpc) contour_eigen, _, _, rotate_back_eigen = transform_to_acpc_standard(contour_acpc, pt0, pt1) @@ -248,7 +238,6 @@ def process_slice( ac_pt_eigen = ac_pt_eigen[:, 0] areas, split_contours = subdivide_contour(contour_eigen, subdivisions, oriented=True, hline_anchor=ac_pt_eigen) split_contours = [rotate_back_eigen(split_contour) for split_contour in split_contours] - split_contours_hofer_frahm = split_contours.copy() total_area = np.sum(areas) total_perimeter = np.sum(np.sqrt(np.sum((np.diff(contour_1mm, axis=0))**2, axis=1))) @@ -256,8 +245,6 @@ def process_slice( # Transform split contours back to original space split_contours = [rotate_back_acpc(split_contour) for split_contour in split_contours] - if split_contours_hofer_frahm is not None: - split_contours_hofer_frahm = [rotate_back_acpc(split_contour) for split_contour in split_contours_hofer_frahm] return { 'cc_index': cc_index, @@ -270,7 +257,6 @@ def process_slice( 'total_area': total_area, 'total_perimeter': total_perimeter, 'split_contours': split_contours, - 'split_contours_hofer_frahm': split_contours_hofer_frahm, 'midline_equidistant': midline_equidistant, 'levelpaths': levelpaths, 'slice_index': slice_idx @@ -466,13 +452,14 @@ def process_slices( cc_mesh.fill_thickness_values() cc_mesh.create_mesh() cc_mesh.smooth_(1) - cc_mesh.plot_mesh(output_path=cc_html_path) + if verbose: + logger.info(f"Saving CC 3D visualization to {cc_html_path}") + cc_mesh.plot_mesh(output_path=str(cc_html_path), show_mesh_edges=True) if vtk_file_path is not None: if verbose: logger.info(f"Saving vtk file to {vtk_file_path}") cc_mesh.write_vtk(str(vtk_file_path)) - #cc_mesh.write_vtk(str(output_dir / 'cc_mesh.vtk')) cc_mesh.to_fs_coordinates(vox2ras_tkr=vox2ras_tkr, vox_size=vox_size) diff --git a/CorpusCallosum/visualization/visualization.py b/CorpusCallosum/visualization/visualization.py index a0f93e12..99027880 100644 --- a/CorpusCallosum/visualization/visualization.py +++ b/CorpusCallosum/visualization/visualization.py @@ -73,7 +73,7 @@ def visualize_coordinate_spaces( pc_coords_3d: np.ndarray, ac_coords_standardized: np.ndarray, pc_coords_standardized: np.ndarray, - output_dir: str | Path, + output_plot_path: str | Path, ) -> None: """Visualize the AC and PC coordinates in different coordinate spaces. 
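The `process_slice` hunk above reduces every subdivision method to the same two aggregates: `total_area` as the sum of the per-segment areas, and `total_perimeter` from first differences along the 1 mm contour. A self-contained sketch of those two geometric primitives (function names are illustrative, not the module's API; the shoelace formula is one standard way to get the area enclosed by a closed contour):

```python
import numpy as np


def contour_perimeter(contour: np.ndarray) -> float:
    """Sum of segment lengths along an (N, 2) polyline (the np.diff pattern above)."""
    return float(np.sum(np.sqrt(np.sum(np.diff(contour, axis=0) ** 2, axis=1))))


def contour_area(contour: np.ndarray) -> float:
    """Shoelace formula for the area enclosed by a closed (N, 2) contour."""
    x, y = contour[:, 0], contour[:, 1]
    return float(0.5 * abs(np.dot(x, np.roll(y, -1)) - np.dot(y, np.roll(x, -1))))


# unit square with repeated end point: perimeter 4, area 1
square = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [0.0, 0.0]])
assert np.isclose(contour_perimeter(square), 4.0)
assert np.isclose(contour_area(square), 1.0)
```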
@@ -100,7 +100,7 @@ def visualize_coordinate_spaces( AC coordinates in standardized space pc_coords_standardized : np.ndarray PC coordinates in standardized space - output_dir : str or Path + output_plot_path : str or Path Directory to save visualization Notes @@ -126,7 +126,7 @@ def visualize_coordinate_spaces( a.set_aspect("equal", adjustable="box") a.axis("off") - plt.savefig(Path(output_dir) / "ac_pc_spaces.png", dpi=300, bbox_inches="tight") + plt.savefig(output_plot_path, dpi=300, bbox_inches="tight") plt.show() plt.close() @@ -134,7 +134,6 @@ def visualize_coordinate_spaces( def plot_contours( transformed: np.ndarray, split_contours: list[np.ndarray] | None = None, - split_contours_hofer_frahm: list[np.ndarray] | None = None, midline_equidistant: np.ndarray | None = None, levelpaths: list[np.ndarray] | None = None, output_path: str | Path | None = None, @@ -142,7 +141,6 @@ def plot_contours( pc_coords: np.ndarray | None = None, vox_size: float | None = None, title: str = "", - debug: bool = False, ) -> None: """Plot contours and subdivisions of the corpus callosum. @@ -152,8 +150,6 @@ def plot_contours( Transformed image data split_contours : list[np.ndarray], optional List of contour arrays for each subdivision, by default None - split_contours_hofer_frahm : list[np.ndarray], optional - List of contour arrays using Hofer-Frahm subdivision, by default None midline_equidistant : np.ndarray, optional Midline points at equidistant spacing, by default None levelpaths : list[np.ndarray], optional @@ -181,69 +177,43 @@ def plot_contours( split_contours = ( [split_contour / vox_size for split_contour in split_contours] if split_contours is not None else None ) - split_contours_hofer_frahm = ( - [split_contour / vox_size for split_contour in split_contours_hofer_frahm] - if split_contours_hofer_frahm is not None - else None - ) - midline_equidistant = midline_equidistant / vox_size - levelpaths = [levelpath / vox_size for levelpath in levelpaths] + midline_equidistant = midline_equidistant / vox_size if midline_equidistant is not None else None + levelpaths = [levelpath / vox_size for levelpath in levelpaths] if levelpaths is not None else None NO_PLOTS = 1 if split_contours is not None: NO_PLOTS += 1 - if split_contours_hofer_frahm is not None: - NO_PLOTS += 1 _, ax = plt.subplots(1, NO_PLOTS, sharex=True, sharey=True, figsize=(15, 10)) - PLT_NUM = 0 + # NOTE: For all plots imshow shows y inverted + current_plot = 0 + if split_contours is not None: - ax[PLT_NUM].imshow(transformed[transformed.shape[0] // 2], cmap="gray") + ax[current_plot].imshow(transformed[transformed.shape[0] // 2], cmap="gray") # ax[0].imshow(cc_mask, cmap='autumn') - ax[PLT_NUM].set_title(title) + ax[current_plot].set_title(title) for i in range(len(split_contours)): - ax[PLT_NUM].fill(split_contours[i][0, :], -split_contours[i][1, :], color="steelblue", alpha=0.25) - ax[PLT_NUM].plot( + ax[current_plot].fill(split_contours[i][0, :], -split_contours[i][1, :], color="steelblue", alpha=0.25) + ax[current_plot].plot( split_contours[i][0, :], -split_contours[i][1, :], color="mediumblue", linestyle="dotted", linewidth=0.7 ) - ax[PLT_NUM].plot(split_contours[0][0, :], -split_contours[0][1, :], color="mediumblue", linewidth=0.7) - ax[PLT_NUM].scatter(ac_coords[1], ac_coords[0], color="red", marker="x") - ax[PLT_NUM].scatter(pc_coords[1], pc_coords[0], color="blue", marker="x") - PLT_NUM += 1 - - if split_contours_hofer_frahm is not None: - ax[PLT_NUM].imshow(transformed[transformed.shape[0] // 2], cmap="gray") - # 
ax[1].imshow(cc_mask, cmap='autumn') - ax[PLT_NUM].set_title("Hofer-Frahm Jaenecke") - for i in range(len(split_contours_hofer_frahm)): - ax[PLT_NUM].fill( - split_contours_hofer_frahm[i][0, :], -split_contours_hofer_frahm[i][1, :], color="steelblue", alpha=0.25 - ) - ax[PLT_NUM].plot( - [split_contours_hofer_frahm[i][0, 0], split_contours_hofer_frahm[i][0, -1]], - [-split_contours_hofer_frahm[i][1, 0], -split_contours_hofer_frahm[i][1, -1]], - color="mediumblue", - linestyle="dotted", - linewidth=0.7, - ) - ax[PLT_NUM].plot( - split_contours_hofer_frahm[0][0, :], -split_contours_hofer_frahm[0][1, :], color="mediumblue", linewidth=0.7 - ) - ax[PLT_NUM].scatter(ac_coords[1], ac_coords[0], color="red", marker="x") - ax[PLT_NUM].scatter(pc_coords[1], pc_coords[0], color="blue", marker="x") - PLT_NUM += 1 - reference_contour = split_contours[0] if split_contours is not None else split_contours_hofer_frahm[0] + ax[current_plot].plot(split_contours[0][0, :], -split_contours[0][1, :], color="mediumblue", linewidth=0.7) + ax[current_plot].scatter(ac_coords[1], ac_coords[0], color="red", marker="x") + ax[current_plot].scatter(pc_coords[1], pc_coords[0], color="blue", marker="x") + current_plot += 1 + + reference_contour = split_contours[0] - ax[PLT_NUM].imshow(transformed[transformed.shape[0] // 2], cmap="gray") + ax[current_plot].imshow(transformed[transformed.shape[0] // 2], cmap="gray") # ax[2].imshow(cc_mask, cmap='autumn') for i in range(len(levelpaths)): - ax[PLT_NUM].plot(levelpaths[i][:, 0], -levelpaths[i][:, 1], color="brown", linewidth=0.8) - ax[PLT_NUM].set_title("Midline & Levelpaths") - ax[PLT_NUM].plot(midline_equidistant[:, 0], -midline_equidistant[:, 1], color="red") - ax[PLT_NUM].plot(reference_contour[0, :], -reference_contour[1, :], color="red", linewidth=0.5) + ax[current_plot].plot(levelpaths[i][:, 0], -levelpaths[i][:, 1], color="brown", linewidth=0.8) + ax[current_plot].set_title("Midline & Levelpaths") + ax[current_plot].plot(midline_equidistant[:, 0], -midline_equidistant[:, 1], color="red") + ax[current_plot].plot(reference_contour[0, :], -reference_contour[1, :], color="red", linewidth=0.5) for a in ax.flatten(): a.set_aspect("equal", adjustable="box") From 4068e493f32b570ebc9a83a0bd4cb40ce61918ea Mon Sep 17 00:00:00 2001 From: ClePol Date: Wed, 12 Nov 2025 10:23:45 +0100 Subject: [PATCH 22/68] added consolidation strategy with WM and ventricle labels --- CorpusCallosum/paint_cc_into_pred.py | 198 ++++++++++++++++++++++++++- 1 file changed, 191 insertions(+), 7 deletions(-) diff --git a/CorpusCallosum/paint_cc_into_pred.py b/CorpusCallosum/paint_cc_into_pred.py index 97cdf74e..b5d5a635 100644 --- a/CorpusCallosum/paint_cc_into_pred.py +++ b/CorpusCallosum/paint_cc_into_pred.py @@ -21,6 +21,9 @@ import nibabel as nib import numpy as np from numpy import typing as npt +from scipy import ndimage + +from FastSurferCNN.data_loader.conform import is_conform HELPTEXT = """ Script to add corpus callosum segmentation (CC, FreeSurfer IDs 251-255) to @@ -74,7 +77,8 @@ def argument_parse(): return args -def paint_in_cc(pred: npt.NDArray[np.int_], aseg_cc: npt.NDArray[np.int_]) -> npt.NDArray[np.int_]: +def paint_in_cc(pred: npt.NDArray[np.int_], + aseg_cc: npt.NDArray[np.int_]) -> npt.NDArray[np.int_]: """Paint corpus callosum segmentation into aseg+dkt segmentation map. 
Parameters @@ -98,21 +102,201 @@ def paint_in_cc(pred: npt.NDArray[np.int_], aseg_cc: npt.NDArray[np.int_]) -> np pred[cc_mask] = aseg_cc[cc_mask] return pred +def correct_wm_ventricles( + aseg_cc: npt.NDArray[np.int_], + fornix_mask: npt.NDArray[np.bool_], + voxel_size: tuple[float, float, float], + close_gap_size_mm: float = 3.0 +) -> npt.NDArray[np.int_]: + """Correct WM mask and ventricle labels according to the CC and fornix masks. + + The function + Take non-CC-connected WM components -> remove + Take FN -> WM + Fill space in superior inferior direction between CC and left/right Ventricle with corresponding Ventricle labels + """ + + # Create a copy to avoid modifying the original + corrected_pred = aseg_cc.copy() + + # Get CC mask (labels 251-255) + cc_mask = (aseg_cc >= 251) & (aseg_cc <= 255) + + # Get left and right ventricle masks + all_ventricle_mask = (aseg_cc == 4) | (aseg_cc == 43) + + # Combine all WM labels + all_wm_mask = (aseg_cc == 2) | (aseg_cc == 41) + + + # 1. Fill space between CC and ventricles + # Only fill small gaps (up to 3 voxels) between CC and ventricle boundaries + #for ventricle_label, ventricle_mask in [(4, left_ventricle_mask), (43, right_ventricle_mask)]: + + # Process each slice independently + for x in range(corrected_pred.shape[0]): + cc_slice = cc_mask[x, :, :] + #vent_slice = ventricle_mask[x, :, :] + all_wm_slice = all_wm_mask[x, :, :] + + + if all_wm_slice.any() and cc_slice.any(): + + # Dilate CC mask to find adjacent voxels, then check for overlap with component + cc_dilated = ndimage.binary_dilation(cc_slice, iterations=1) + # Label connected components in WM + labeled_wm, num_components = ndimage.label(all_wm_slice) + + # Find components that are adjacent to CC and remove them + for label in range(1, num_components + 1): + component_mask = labeled_wm == label + # Check if this component is adjacent to (touches) the CC + if np.any(component_mask & cc_dilated): + corrected_pred[x, :, :][component_mask] = 0 # Set to background + + + if fornix_mask[x, :, :].any(): + fornix_slice = fornix_mask[x, :, :] + # count WM labels overlapping with fornix + left_wm_overlap = np.sum(fornix_slice & (aseg_cc == 2)) + right_wm_overlap = np.sum(fornix_slice & (aseg_cc == 41)) + if left_wm_overlap > right_wm_overlap: + corrected_pred[x, :, :][fornix_slice] = 2 # Left WM + else: + corrected_pred[x, :, :][fornix_slice] = 41 # Right WM + + + vent_slice = all_ventricle_mask[x, :, :] + + if cc_slice.any() and vent_slice.any(): + # Create binary masks for this slice + cc_binary = cc_slice.astype(bool) + vent_binary = vent_slice.astype(bool) + + # Dilate both masks slightly to find potential connection points + max_gap_vox = int(np.ceil(voxel_size[1] * close_gap_size_mm)) + cc_dilated = ndimage.binary_dilation(cc_binary, iterations=max_gap_vox) + vent_dilated = ndimage.binary_dilation(vent_binary, iterations=max_gap_vox) + + # Find voxels that are adjacent to both CC and ventricle but not part of either + potential_fill = (cc_dilated & vent_dilated) & ~(cc_binary | vent_binary) + + # Only fill small gaps between CC and ventricle in inferior-superior direction + if potential_fill.any(): + + + for z in range(potential_fill.shape[1]): + potential_fill_line = potential_fill[:, z] + labeled_gaps, num_gaps = ndimage.label(potential_fill_line) + cc_line = cc_binary[:, z] + vent_line = vent_binary[:, z] + + + + + + for gap_label in range(1, num_gaps + 1): + gap_mask = labeled_gaps == gap_label + + # check that CC and ventricle are connected to the gap_mask + dilated_gap_mask = 
ndimage.binary_dilation(gap_mask, iterations=1) + if not np.any(cc_line & dilated_gap_mask): + continue + if not np.any(vent_line & dilated_gap_mask): + continue + + vent_label_location = np.where(vent_line & dilated_gap_mask)[0] + vent_label = corrected_pred[x, vent_label_location, z] + + + + + + if np.sum(gap_mask) > max_gap_vox: + continue + + corrected_pred[x, :, z][gap_mask & (corrected_pred[x, :, z] == 0)] = vent_label + + # Process gaps in z-direction (within each y-row) + for y in range(potential_fill.shape[0]): + potential_fill_line = potential_fill[y, :] + labeled_gaps, num_gaps = ndimage.label(potential_fill_line) + cc_line = cc_binary[y, :] + vent_line = vent_binary[y, :] + + for gap_label in range(1, num_gaps + 1): + gap_mask = labeled_gaps == gap_label + + # check that CC and ventricle are connected to the gap_mask + dilated_gap_mask = ndimage.binary_dilation(gap_mask, iterations=1) + if not np.any(cc_line & dilated_gap_mask): + continue + if not np.any(vent_line & dilated_gap_mask): + continue + + vent_label_location = np.where(vent_line & dilated_gap_mask)[0] + if len(vent_label_location) > 0: + vent_label = corrected_pred[x, y, vent_label_location[0]] # Take first match + + if np.sum(gap_mask) > max_gap_vox: + continue + + corrected_pred[x, y, :][gap_mask & (corrected_pred[x, y, :] == 0)] = vent_label + + + return corrected_pred + if __name__ == "__main__": # Command Line options are error checking done here options = argument_parse() + + print(f"Reading inputs: {options.input_cc} {options.input_pred}...") - aseg_image = np.asanyarray(nib.load(options.input_cc).dataobj) - prediction = nib.load(options.input_pred) - pred_with_cc = paint_in_cc(np.asanyarray(prediction.dataobj), aseg_image) + cc_seg_image = nib.load(options.input_cc) + cc_seg_data = np.asanyarray(cc_seg_image.dataobj) + aseg_image = nib.load(options.input_pred) + aseg_data = np.asanyarray(aseg_image.dataobj) + + cc_conformed = is_conform(cc_seg_image, vox_size=None, img_size=None, verbose=False) + pred_conformed = is_conform(aseg_image, vox_size=None, img_size=None, verbose=False) + if not cc_conformed: + print("Warning: CC input image is not conformed (LIA orientation, uint8 dtype). \ + Please conform the image using the conform.py script.") + if not pred_conformed: + print("Warning: Prediction input image is not conformed (LIA orientation, uint8 dtype). 
\
+            Please conform the image using the conform.py script.")
+
+    # Count initial labels
+    initial_cc = np.sum((aseg_data >= 251) & (aseg_data <= 255))
+    initial_fornix = np.sum(aseg_data == 250)
+    initial_wm = np.sum((aseg_data == 2) | (aseg_data == 41))
+    print(f"Initial segmentation: CC={initial_cc}, Fornix={initial_fornix}, WM={initial_wm}")
+
+    # Paint CC into prediction
+    pred_with_cc = paint_in_cc(aseg_data, cc_seg_data)
+    after_paint_cc = np.sum((pred_with_cc >= 251) & (pred_with_cc <= 255))
+    print(f"After painting CC: {after_paint_cc} CC voxels added")
+
+    # Apply WM and ventricle corrections
+    print("Applying white matter and ventricle corrections...")
+    fornix_mask = cc_seg_data == 250
+    voxel_size = tuple(aseg_image.header.get_zooms())
+    pred_corrected = correct_wm_ventricles(aseg_data, fornix_mask, voxel_size)
+
+    # Count final labels
+    final_cc = np.sum((pred_corrected >= 251) & (pred_corrected <= 255))
+    final_fornix = np.sum(pred_corrected == 250)
+    final_wm = np.sum((pred_corrected == 2) | (pred_corrected == 41))
+    final_ventricles = np.sum((pred_corrected == 4) | (pred_corrected == 43))
+
+    print(f"Final segmentation: CC={final_cc}, Fornix={final_fornix}, WM={final_wm}, Ventricles={final_ventricles}")
+    print(f"Changes: CC +{final_cc-initial_cc}, Fornix {final_fornix-initial_fornix}, WM {final_wm-initial_wm}")
 
     print(f"Writing segmentation with corpus callosum to: {options.output}")
-    pred_with_cc_fin = nib.MGHImage(pred_with_cc, prediction.affine, prediction.header)
+    pred_with_cc_fin = nib.MGHImage(pred_corrected, aseg_image.affine, aseg_image.header)
     pred_with_cc_fin.to_filename(options.output)
 
     sys.exit(0)
-
-# TODO: Rename the file (paint_cc_into_asegdkt or similar) and functions.

From 1b8ecc92c81d28772737e9a29b628055d6da97cd Mon Sep 17 00:00:00 2001
From: ClePol
Date: Wed, 12 Nov 2025 14:15:44 +0100
Subject: [PATCH 23/68] FastSurfer style weights loading + removed superfluous
 plot

---
 CorpusCallosum/data/constants.py              |  3 +-
 CorpusCallosum/fastsurfer_cc.py               | 33 +++----------
 .../localization/localization_inference.py    |  6 ++--
 .../segmentation/segmentation_inference.py    |  5 +--
 4 files changed, 12 insertions(+), 35 deletions(-)

diff --git a/CorpusCallosum/data/constants.py b/CorpusCallosum/data/constants.py
index ba2ef46d..e0789cf0 100644
--- a/CorpusCallosum/data/constants.py
+++ b/CorpusCallosum/data/constants.py
@@ -16,13 +16,14 @@
 from pathlib import Path
 
 ### Constants
-WEIGHTS_PATH = Path(__file__).parent.parent / "weights"
+WEIGHTS_PATH = Path(__file__).parent.parent.parent / "checkpoints"
 FSAVERAGE_CENTROIDS_PATH = Path(__file__).parent / "fsaverage_centroids.json"
 FSAVERAGE_DATA_PATH = Path(__file__).parent / "fsaverage_data.json"  # Contains both affine and header
 FSAVERAGE_MIDDLE = 128  # Middle slice index in fsaverage space
 CC_LABEL = 192  # Label value for corpus callosum in segmentation
 FORNIX_LABEL = 250  # Label value for fornix in segmentation
 SUBSEGEMNT_LABELS = [251, 252, 253, 254, 255]  # labels for subsegments in segmentation
+FASTSURFER_ROOT = Path(__file__).parent.parent.parent  # TODO: use FastSurfer function for this
 
 
 STANDARD_OUTPUT_PATHS = {
diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py
index d5b02f26..2434822d 100644
--- a/CorpusCallosum/fastsurfer_cc.py
+++ b/CorpusCallosum/fastsurfer_cc.py
@@ -14,7 +14,6 @@
     FSAVERAGE_DATA_PATH,
     FSAVERAGE_MIDDLE,
     STANDARD_OUTPUT_PATHS,
-    WEIGHTS_PATH,
 )
 from CorpusCallosum.data.read_write import (
     convert_numpy_to_json_serializable,
@@ -603,14 +602,9 @@
     device = 
torch.device("cuda" if torch.cuda.is_available() and not cpu else "cpu") logger.info(f"Using device: {device}") - logger.info("Loading localization model") - model_localization = localization_inference.load_model( - str(Path(WEIGHTS_PATH) / "localization_weights_acpc.pth"), device=device - ) - logger.info("Loading segmentation model") - model_segmentation = segmentation_inference.load_model( - str(Path(WEIGHTS_PATH) / "segmentation_weights_cc_fn.pth"), device=device - ) + logger.info("Loading models") + model_localization = localization_inference.load_model(device=device) + model_segmentation = segmentation_inference.load_model(device=device) aseg_nib = nib.load(aseg_path) @@ -700,6 +694,7 @@ def main( verbose=verbose, save_template=save_template, ) + IO_processes.extend(slice_io_processes) outer_contours = [slice_result['split_contours'][0] for slice_result in slice_results] @@ -708,8 +703,6 @@ def main( logger.warning("Large area changes detected between consecutive slices, " "this is likely due to a segmentation error.") - IO_processes.extend(slice_io_processes) - # Get middle slice result for backward compatibility middle_slice_result = slice_results[len(slice_results) // 2] @@ -737,24 +730,6 @@ def main( ) ) - # Save middle slice visualization - IO_processes.append( - create_visualization( - subdivision_method, - { - "split_contours": middle_slice_result["split_contours"], - "midline_equidistant": middle_slice_result["midline_equidistant"], - "levelpaths": middle_slice_result["levelpaths"], - }, - midslices, - output_dir, - ac_coords, - pc_coords, - orig.header.get_zooms()[0], - " (Middle Slice)", - ) - ) - save_nifti_background(IO_processes, segmentation, seg_affine, orig.header, segmentation_path) diff --git a/CorpusCallosum/localization/localization_inference.py b/CorpusCallosum/localization/localization_inference.py index 6527acdc..df108ace 100644 --- a/CorpusCallosum/localization/localization_inference.py +++ b/CorpusCallosum/localization/localization_inference.py @@ -19,14 +19,14 @@ from monai import transforms from monai.networks.nets import DenseNet +from CorpusCallosum.data import constants from CorpusCallosum.transforms.localization_transforms import CropAroundACPCFixedSize from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults from FastSurferCNN.download_checkpoints import main as download_checkpoints -def load_model(checkpoint_path: str | Path | None = None, - device: torch.device | None = None) -> DenseNet: +def load_model(device: torch.device | None = None) -> DenseNet: """Load trained numerical localization model from checkpoint. 
Parameters @@ -65,7 +65,7 @@ def load_model(checkpoint_path: str | Path | None = None, "checkpoint", filename=CC_YAML, ) - checkpoint_path = cc_config['localization'] + checkpoint_path = constants.FASTSURFER_ROOT / cc_config['localization'] # Load state dict if isinstance(checkpoint_path, str) or isinstance(checkpoint_path, Path): diff --git a/CorpusCallosum/segmentation/segmentation_inference.py b/CorpusCallosum/segmentation/segmentation_inference.py index a013fdc7..b0b6c0b6 100644 --- a/CorpusCallosum/segmentation/segmentation_inference.py +++ b/CorpusCallosum/segmentation/segmentation_inference.py @@ -17,6 +17,7 @@ import torch from monai import transforms +from CorpusCallosum.data import constants from CorpusCallosum.transforms.segmentation_transforms import CropAroundACPC from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults @@ -24,7 +25,7 @@ from FastSurferCNN.models.networks import FastSurferVINN -def load_model(checkpoint_path: str | None = None, device: torch.device | None = None) -> FastSurferVINN: +def load_model(device: torch.device | None = None) -> FastSurferVINN: """Load trained model from checkpoint. Parameters @@ -70,7 +71,7 @@ def load_model(checkpoint_path: str | None = None, device: torch.device | None = "checkpoint", filename=CC_YAML, ) - checkpoint_path = cc_config['segmentation'] + checkpoint_path = constants.FASTSURFER_ROOT / cc_config['segmentation'] weights = torch.load(checkpoint_path, weights_only=True, map_location=device) model.load_state_dict(weights) From 257068d9892ce516c59d851b6ac303ad72b7e1f5 Mon Sep 17 00:00:00 2001 From: ClePol Date: Wed, 12 Nov 2025 14:18:16 +0100 Subject: [PATCH 24/68] recon-surf integration with FastSurferCNN label consolidation --- CorpusCallosum/paint_cc_into_pred.py | 10 ---------- recon_surf/recon-surf.sh | 26 +++++++++++++++++++++++--- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/CorpusCallosum/paint_cc_into_pred.py b/CorpusCallosum/paint_cc_into_pred.py index b5d5a635..7815f9a1 100644 --- a/CorpusCallosum/paint_cc_into_pred.py +++ b/CorpusCallosum/paint_cc_into_pred.py @@ -183,18 +183,12 @@ def correct_wm_ventricles( # Only fill small gaps between CC and ventricle in inferior-superior direction if potential_fill.any(): - - for z in range(potential_fill.shape[1]): potential_fill_line = potential_fill[:, z] labeled_gaps, num_gaps = ndimage.label(potential_fill_line) cc_line = cc_binary[:, z] vent_line = vent_binary[:, z] - - - - for gap_label in range(1, num_gaps + 1): gap_mask = labeled_gaps == gap_label @@ -208,10 +202,6 @@ def correct_wm_ventricles( vent_label_location = np.where(vent_line & dilated_gap_mask)[0] vent_label = corrected_pred[x, vent_label_location, z] - - - - if np.sum(gap_mask) > max_gap_vox: continue diff --git a/recon_surf/recon-surf.sh b/recon_surf/recon-surf.sh index d415fd7b..f0199664 100755 --- a/recon_surf/recon-surf.sh +++ b/recon_surf/recon-surf.sh @@ -619,6 +619,8 @@ fi # ============================= CC SEGMENTATION ============================================ + + { echo " " echo "============ Creating and adding CC Segmentation ============" @@ -627,11 +629,29 @@ fi # create aseg.auto including corpus callosum segmentation and 46 sec, requires norm.mgz # Note: if original input segmentation already contains CC, this will exit with ERROR # in the future maybe check and skip this step (and next) -#cmd="mri_cc -aseg $aseg_nocc -o aseg.auto.mgz -lta $mdir/transforms/cc_up.lta $subject" -#RunIt 
"$cmd" "$LF" +cmd="$python ${binpath}../CorpusCallosum/fastsurfer_cc.py --subject_dir $SUBJECTS_DIR/$subject --verbose" +RunIt "$cmd" "$LF" # add CC into aparc.DKTatlas+aseg.deep (not sure if this is really needed) -cmd="$python ${binpath}/../CorpusCallosum/paint_cc_into_pred.py -in_cc $mdir/aseg.auto.mgz -in_pred $asegdkt_segfile -out $mdir/aparc.DKTatlas+aseg.deep.withCC.mgz" +cmd="$python ${FASTSURFER_HOME}/CorpusCallosum/paint_cc_into_pred.py -in_cc $mdir/callosum_seg_aseg_space.mgz -in_pred $asegdkt_segfile -out $mdir/aparc.DKTatlas+aseg.deep.withCC.mgz" RunIt "$cmd" "$LF" +# add CC into aseg.auto.mgz as mri_cc did before. Not sure where this is used. +cmd="$python ${FASTSURFER_HOME}/CorpusCallosum/paint_cc_into_pred.py -in_cc $mdir/callosum_seg_aseg_space.mgz -in_pred $mdir/$aseg_nocc -out $mdir/aseg.auto.mgz" +RunIt "$cmd" "$LF" + + +# { +# echo " " +# echo "============ Creating and adding CC Segmentation (mri_cc) ============" +# echo " " +# } | tee -a "$LF" +# # create aseg.auto including corpus callosum segmentation and 46 sec, requires norm.mgz +# # Note: if original input segmentation already contains CC, this will exit with ERROR +# # in the future maybe check and skip this step (and next) +# cmd="mri_cc -aseg $aseg_nocc -o aseg.auto.mgz -lta $mdir/transforms/cc_up.lta $subject" +# RunIt "$cmd" "$LF" +# # add CC into aparc.DKTatlas+aseg.deep (not sure if this is really needed) +# cmd="$python ${binpath}../CorpusCallosum/paint_cc_into_pred.py -in_cc $mdir/aseg.auto.mgz -in_pred $asegdkt_segfile -out $mdir/aparc.DKTatlas+aseg.deep.withCC.mgz" +# RunIt "$cmd" "$LF" # ============================= FILLED ===================================================== From 9a9ebacd5cb994afa7782d780ec1954f8401d35f Mon Sep 17 00:00:00 2001 From: ClePol Date: Wed, 12 Nov 2025 16:59:16 +0100 Subject: [PATCH 25/68] updated commandline interface --- CorpusCallosum/README.md | 2 +- CorpusCallosum/data/constants.py | 7 +- CorpusCallosum/fastsurfer_cc.py | 252 ++++++++++++++-------- CorpusCallosum/shape/cc_postprocessing.py | 34 +-- 4 files changed, 183 insertions(+), 112 deletions(-) diff --git a/CorpusCallosum/README.md b/CorpusCallosum/README.md index 3c47d282..1c2c3162 100644 --- a/CorpusCallosum/README.md +++ b/CorpusCallosum/README.md @@ -76,7 +76,7 @@ Choose one of these input methods: - `--upright_lta_path PATH`: Path for upright LTA transform - `--orient_volume_lta_path PATH`: Path for orientation volume LTA transform - `--orig_space_segmentation_path PATH`: Path for segmentation in original space -- `--debug_image_path PATH`: Path for debug visualization image +- `--qc_image_path PATH`: Path for QC visualization image **Template Saving:** - `--save_template PATH`: Directory path to save contours.txt and thickness_values.txt files diff --git a/CorpusCallosum/data/constants.py b/CorpusCallosum/data/constants.py index e0789cf0..dd20a8c5 100644 --- a/CorpusCallosum/data/constants.py +++ b/CorpusCallosum/data/constants.py @@ -26,6 +26,11 @@ FASTSURFER_ROOT = Path(__file__).parent.parent.parent # TODO: use FastSurfer function for this +STANDARD_INPUT_PATHS = { + "t1": "mri/orig.mgz", + "aseg_name": "mri/aparc.DKTatlas+aseg.deep.mgz", +} + STANDARD_OUTPUT_PATHS = { ## images "upright_volume": None, # orig.mgz mapped to upright space @@ -42,7 +47,7 @@ "upright_lta": "mri/transforms/cc_up.lta", # lta transform from orig to upright space "orient_volume_lta": "mri/transforms/orient_volume.lta", # lta transform from orig to upright+acpc corrected space ## qc - "debug_image": 
"qc_snapshots/callosum.png", # debug image of cc contours + "qc_image": "qc_snapshots/callosum.png", # debug image of cc contours "thickness_image": "qc_snapshots/callosum_thickness.png", # whippersnappy 3D image of cc thickness "cc_html": "qc_snapshots/corpus_callosum.html", # plotly cc visualization ## surface diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index 2434822d..ab705a52 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -13,6 +13,7 @@ FSAVERAGE_CENTROIDS_PATH, FSAVERAGE_DATA_PATH, FSAVERAGE_MIDDLE, + STANDARD_INPUT_PATHS, STANDARD_OUTPUT_PATHS, ) from CorpusCallosum.data.read_write import ( @@ -34,7 +35,6 @@ from CorpusCallosum.segmentation import segmentation_inference, segmentation_postprocessing from CorpusCallosum.shape.cc_postprocessing import ( check_area_changes, - create_visualization, make_subdivision_mask, process_slices, ) @@ -48,35 +48,44 @@ def options_parse() -> argparse.Namespace: """Parse command line arguments for the pipeline.""" parser = argparse.ArgumentParser() + + # Specify subject directory + subject ID, OR specify individual MRI and segmentation files + output paths + mgroup = parser.add_mutually_exclusive_group() + mgroup.add_argument( + "--sd", + type=Path, + help="Root directory in which the case directory is located. " + "Must be used together with --sid.", + ) parser.add_argument( - "--in_mri", + "--sid", type=str, - required=False, - help="Input MRI file path. If not provided, defaults to subject_dir/mri/orig.mgz", + help="Name of the case directory. Must be used together with --sd.", ) - parser.add_argument( - "--cpu", - action="store_true", - help="Force CPU usage even when CUDA is available", + mgroup.add_argument( + "--t1", + type=Path, + help=f"Input MRI file path. Must be used together with --aseg_name. \ + (default: subject_dir/{STANDARD_INPUT_PATHS['t1']})", ) parser.add_argument( - "--aseg", - type=str, - required=False, - help="Input segmentation file path. If not provided, defaults to subject_dir/mri/aparc.DKTatlas+aseg.deep.mgz", + "--aseg_name", + type=Path, + help=f"Input segmentation file path. Must be used together with --t1. \ + (default: subject_dir/{STANDARD_INPUT_PATHS['aseg_name']})", ) parser.add_argument( - "--subject_dir", + "--device", type=str, - required=False, - help="Subject directory containing standard FreeSurfer structure. " - "Required if --in_mri and --aseg are not both provided.", - default=None, + default="auto", + help="Select device to run inference on: cpu, or cuda (= Nvidia gpu) or specify a certain gpu (e.g. cuda:1), \ + Default: auto", ) - parser.add_argument("--debug_output_dir", type=str, required=False, default=None, - help="Directory for debug output (default: subject_dir/qc_snapshots)") + parser.add_argument( - "--num_thickness_points", type=int, default=100, help="Number of points for thickness estimation." + "--num_thickness_points", + type=int, default=100, + help="Number of points for thickness estimation." ) parser.add_argument( "--subdivisions", @@ -100,127 +109,176 @@ def options_parse() -> argparse.Namespace: "--contour_smoothing", type=float, default=5, - help="Window size for smoothing during contour detection. Default is 5, higher values mean a smoother" - "outline, at the cost of precision.", + help="Gaussian sigma for smoothing during contour detection. Higher values mean a smoother" + " CC outline, at the cost of precision. 
(default: 5)", ) parser.add_argument( "--slice_selection", type=str, default="all", - help="Which slices to process. Options: 'middle' (default), 'all', or a specific slice number.", + help="Which slices to process. Options: 'middle', 'all', or a specific slice number. \ + (default: 'all')", ) parser.add_argument( + "--verbose", + action="store_true", + help="Enable verbose (shows output paths)", + default=False, + ) + + ######## OUTPUT PATHS ######### + # 4. Options for advanced, technical parameters + advanced = parser.add_argument_group(title="Advanced options", + description="Custom output paths, useful if no standard case directory is used.") + advanced.add_argument("--qc_output_dir", + type=Path, + required=False, + default=None, + help="Directory for quality control output (default: subject_dir/qc_snapshots)") + advanced.add_argument( "--upright_volume_path", - type=str, - help=f"Path for upright volume output (default: subject_dir/{STANDARD_OUTPUT_PATHS['upright_volume']})", + type=Path, + help="Path for upright volume output (default: No output)", default=None, ) - parser.add_argument( + advanced.add_argument( "--segmentation_path", - type=str, + type=Path, help=f"Path for segmentation output (default: subject_dir/{STANDARD_OUTPUT_PATHS['segmentation']})", default=None, ) - parser.add_argument( + advanced.add_argument( "--postproc_results_path", - type=str, - help=f"Path for postprocessing results (default: subject_dir/{STANDARD_OUTPUT_PATHS['postproc_results']})", + type=Path, + help=f"Path for postprocessing results. Contains metrics describing CC shape and volume for each slice \ + (default: subject_dir/{STANDARD_OUTPUT_PATHS['postproc_results']})", default=None, ) - parser.add_argument( + advanced.add_argument( "--cc_markers_path", - type=str, - help=f"Path for CC markers output (default: subject_dir/{STANDARD_OUTPUT_PATHS['cc_markers']})", + type=Path, + help=f"Path for CC markers output. Contains metrics describing CC shape and volume \ + (default: subject_dir/{STANDARD_OUTPUT_PATHS['cc_markers']})", default=None, ) - parser.add_argument( + advanced.add_argument( "--upright_lta_path", - type=str, - help=f"Path for upright LTA transform (default: subject_dir/{STANDARD_OUTPUT_PATHS['upright_lta']})", + type=Path, + help=f"Path for upright LTA transform. This makes sure the midplane is at 128 in LR direction, but no nodding \ + correction is applied (default: subject_dir/{STANDARD_OUTPUT_PATHS['upright_lta']})", default=None, ) - parser.add_argument( + advanced.add_argument( "--orient_volume_lta_path", - type=str, - help="Path for orientation volume LTA transform " - f"(default: subject_dir/{STANDARD_OUTPUT_PATHS['orient_volume_lta']})", + type=Path, + help=f"Path for orientation volume LTA transform. This makes sure the midplane is at 128 in LR direction, \ + and the AC & PC are on the coordinate line, standardizing the head orientation. 
\ + (default: subject_dir/{STANDARD_OUTPUT_PATHS['orient_volume_lta']})", default=None, ) - parser.add_argument( + advanced.add_argument( "--orig_space_segmentation_path", - type=str, - help="Path for segmentation in original space " + type=Path, + help="Path for segmentation in the input MRI space " f"(default: subject_dir/{STANDARD_OUTPUT_PATHS['orig_space_segmentation']})", default=None, ) - parser.add_argument( - "--debug_image_path", - type=str, - help=f"Path for debug visualization image (default: subject_dir/{STANDARD_OUTPUT_PATHS['debug_image']})", + advanced.add_argument( + "--qc_image_path", + type=Path, + help=f"Path for QC visualization image (default: subject_dir/{STANDARD_OUTPUT_PATHS['qc_image']})", default=None, ) - parser.add_argument( + advanced.add_argument( "--save_template", - type=str, - help="Directory path where to save contours.txt and thickness_values.txt files", + type=Path, + help="Directory path where to save contours.txt and thickness_values.txt files. \ + These files can be used to visualize the CC shape and volume in 3D.", default=None, ) - parser.add_argument( + advanced.add_argument( "--thickness_image_path", - type=str, + type=Path, help=f"Path for thickness image (default: subject_dir/{STANDARD_OUTPUT_PATHS['thickness_image']})", default=None, ) - parser.add_argument( + advanced.add_argument( "--surf_file_path", - type=str, + type=Path, help=f"Path for surf file (default: subject_dir/{STANDARD_OUTPUT_PATHS['surf_file']})", default=None, ) - parser.add_argument( + advanced.add_argument( "--overlay_file_path", - type=str, + type=Path, help=f"Path for overlay file (default: subject_dir/{STANDARD_OUTPUT_PATHS['overlay_file']})", default=None, ) - parser.add_argument( + advanced.add_argument( "--cc_html_path", - type=str, - help=f"Path for CC HTML file (default: subject_dir/{STANDARD_OUTPUT_PATHS['cc_html']})", + type=Path, + help=f"Path to CC 3D visualization for CC HTML file (default: subject_dir/{STANDARD_OUTPUT_PATHS['cc_html']})", default=None, ) - parser.add_argument( + advanced.add_argument( "--vtk_file_path", - type=str, - help=f"Path for vtk file (default: subject_dir/{STANDARD_OUTPUT_PATHS['vtk_file']})", + type=Path, + help=f"Path for vtk file, showing the CC 3D mesh (default: subject_dir/{STANDARD_OUTPUT_PATHS['vtk_file']})", default=None, ) - parser.add_argument( + advanced.add_argument( "--softlabels_cc_path", - type=str, - help=f"Path for cc softlabels (default: subject_dir/{STANDARD_OUTPUT_PATHS['softlabels_cc']})", + type=Path, + help=f"Path for cc softlabels. Contains the probability of each voxel being part of the CC \ + (default: subject_dir/{STANDARD_OUTPUT_PATHS['softlabels_cc']})", default=None, ) - parser.add_argument( + advanced.add_argument( "--softlabels_fn_path", - type=str, - help=f"Path for fornix softlabels (default: subject_dir/{STANDARD_OUTPUT_PATHS['softlabels_fn']})", + type=Path, + help=f"Path for fornix softlabels. Contains the probability of each voxel being part of the Fornix \ + (default: subject_dir/{STANDARD_OUTPUT_PATHS['softlabels_fn']})", default=None, ) - parser.add_argument( + advanced.add_argument( "--softlabels_background_path", - type=str, - help=f"Path for background softlabels (default: subject_dir/{STANDARD_OUTPUT_PATHS['softlabels_background']})", + type=Path, + help=f"Path for background softlabels. 
Contains the probability of each voxel being part of the background \ + (default: subject_dir/{STANDARD_OUTPUT_PATHS['softlabels_background']})", default=None, ) - parser.add_argument("--verbose", action="store_true", help="Enable verbose (shows output paths)", default=False) + ############ END OF OUTPUT PATHS ############ + args = parser.parse_args() - # Validation logic: either subject_dir OR both in_mri and aseg must be provided - if not args.subject_dir and (not args.in_mri or not args.aseg): - parser.error("You must specify either --subject_dir OR both --in_mri and --aseg arguments.") + # Reconstruct subject_dir from sd and sid (but sd might be stored as out_dir by parser_defaults) + sd_value = getattr(args, 'sd', getattr(args, 'out_dir', None)) + if sd_value and hasattr(args, 'sid') and args.sid: + args.subject_dir = str(Path(sd_value) / args.sid) + else: + args.subject_dir = None + + # Validation logic: must use either directory approach (--sd + --sid) OR file approach (--t1 + --aseg_name) + if sd_value: + # Using directory approach - make sure sid was also provided + if not (hasattr(args, 'sid') and args.sid): + parser.error("When using --sd, you must also provide --sid.") + elif hasattr(args, 'sid') and args.sid: + # If sid is provided without sd, that's an error + if not sd_value: + parser.error("When using --sid, you must also provide --sd.") + elif hasattr(args, 't1') and args.t1: + # Using file approach - make sure aseg_name was also provided + if not (hasattr(args, 'aseg_name') and args.aseg_name): + parser.error("When using --t1, you must also provide --aseg_name.") + elif hasattr(args, 'aseg_name') and args.aseg_name: + # If aseg_name is provided without t1, that's an error + if not (hasattr(args, 't1') and args.t1): + parser.error("When using --aseg_name, you must also provide --t1.") + else: + parser.error("You must specify either --sd and --sid OR both --t1 and --aseg_name.") # If subject_dir is provided, set default paths for missing arguments if args.subject_dir: @@ -231,11 +289,11 @@ def options_parse() -> argparse.Namespace: (subject_dir_path / "stats").mkdir(parents=True, exist_ok=True) (subject_dir_path / "transforms").mkdir(parents=True, exist_ok=True) - if not args.in_mri: - args.in_mri = str(subject_dir_path / "mri" / "orig.mgz") + if not args.t1: + args.t1 = str(subject_dir_path / STANDARD_INPUT_PATHS["t1"]) - if not args.aseg: - args.aseg = str(subject_dir_path / "mri" / "aparc.DKTatlas+aseg.deep.mgz") + if not args.aseg_name: + args.aseg_name = str(subject_dir_path / STANDARD_INPUT_PATHS["aseg_name"]) # Set default output paths if not provided for key, value in STANDARD_OUTPUT_PATHS.items(): @@ -445,14 +503,14 @@ def main( aseg_path: str | Path, output_dir: str | Path, slice_selection: str = "middle", - debug_output_dir: str | Path = None, + qc_output_dir: str | Path = None, verbose: bool = False, num_thickness_points: int = 100, subdivisions: list[float] | None = None, subdivision_method: str = "shape", contour_smoothing: float = 5, save_template: str | Path | None = None, - cpu: bool = False, + device: str = "auto", upright_volume_path: str | Path = None, segmentation_path: str | Path = None, postproc_results_path: str | Path = None, @@ -464,7 +522,7 @@ def main( cc_html_path: str | Path = None, vtk_file_path: str | Path = None, orig_space_segmentation_path: str | Path = None, - debug_image_path: str | Path = None, + qc_image_path: str | Path = None, thickness_image_path: str | Path = None, softlabels_cc_path: str | Path = None, softlabels_fn_path: str | 
Path = None, @@ -485,8 +543,8 @@ def main( Directory for output files. slice_selection : str, optional Which slices to process ('middle', 'all', or specific slice number), by default 'middle'. - debug_output_dir : str or Path, optional - Directory for debug outputs, by default None. + qc_output_dir : str or Path, optional + Directory for quality control outputs, by default None. verbose : bool, optional Flag for verbose output, by default False. num_thickness_points : int, optional @@ -499,8 +557,8 @@ def main( Gaussian sigma for smoothing during contour detection, by default 5. save_template : str or Path, optional Directory path where to save contours.txt and thickness_values.txt files, by default None. - cpu : bool, optional - Force CPU usage even when CUDA is available, by default False. + device : str, optional + Device to run inference on ('auto', 'cpu', 'cuda', or 'cuda:X'), by default 'auto'. upright_volume_path : str or Path, optional Path to save upright volume, by default None. segmentation_path : str or Path, optional @@ -523,8 +581,8 @@ def main( Path to save VTK file, by default None. orig_space_segmentation_path : str or Path, optional Path to save segmentation in original space, by default None. - debug_image_path : str or Path, optional - Path to save debug images, by default None. + qc_image_path : str or Path, optional + Path to save QC images, by default None. thickness_image_path : str or Path, optional Path to save thickness visualization, by default None. softlabels_cc_path : str or Path, optional @@ -569,7 +627,7 @@ def main( in_mri_path = Path(in_mri_path) aseg_path = Path(aseg_path) output_dir = Path(output_dir) - debug_output_dir = Path(debug_output_dir) if debug_output_dir else None + qc_output_dir = Path(qc_output_dir) if qc_output_dir else None save_template = Path(save_template) if save_template else None # Validate subdivision fractions @@ -599,7 +657,10 @@ def main( raise ValueError("MRI is not conformed, please run conform.py or mri_convert to conform the image.") # load models - device = torch.device("cuda" if torch.cuda.is_available() and not cpu else "cpu") + if device == "auto": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + device = torch.device(device) logger.info(f"Using device: {device}") logger.info("Loading models") @@ -682,7 +743,7 @@ def main( subdivisions=subdivisions, subdivision_method=subdivision_method, contour_smoothing=contour_smoothing, - debug_image_path=debug_image_path, + qc_image_path=qc_image_path, one_debug_image=True, surf_file_path=surf_file_path, overlay_file_path=overlay_file_path, @@ -847,9 +908,14 @@ def main( options = options_parse() main_args = vars(options) + # Remove parser_defaults arguments that are not needed by main() + main_args.pop("sd", None) + main_args.pop("out_dir", None) + main_args.pop("sid", None) + # Rename keys to match main function parameters - main_args["in_mri_path"] = main_args.pop("in_mri") - main_args["aseg_path"] = main_args.pop("aseg") + main_args["in_mri_path"] = main_args.pop("t1") + main_args["aseg_path"] = main_args.pop("aseg_name") main_args["output_dir"] = main_args.pop("subject_dir", ".") main(**main_args) diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/cc_postprocessing.py index a04caf54..e11e490b 100644 --- a/CorpusCallosum/shape/cc_postprocessing.py +++ b/CorpusCallosum/shape/cc_postprocessing.py @@ -274,7 +274,7 @@ def process_slices( subdivisions: list[float], subdivision_method: str, contour_smoothing: float, - 
debug_image_path: str | None = None, + qc_image_path: str | None = None, one_debug_image: bool = False, thickness_image_path: str | None = None, vox_size: tuple[float, float, float] | None = None, @@ -310,8 +310,8 @@ def process_slices( Method for contour subdivision. contour_smoothing : float Gaussian sigma for contour smoothing. - debug_image_path : str or None, optional - Path for debug visualization image, by default None. + qc_image_path : str or None, optional + Path for QC visualization image, by default None. one_debug_image : bool, optional Whether to save only one debug image, by default False. thickness_image_path : str or None, optional @@ -369,13 +369,13 @@ def process_slices( contour_with_thickness[1], start_end_idx=(anterior_endpoint_idx, posterior_endpoint_idx)) - if result is not None and debug_image_path is not None: + if result is not None and qc_image_path is not None: slice_results.append(result) # Create visualization if verbose: - logger.info(f"Saving segmentation qc image to {debug_image_path}") - IO_processes.append(create_visualization(subdivision_method, result, midslices, - debug_image_path, ac_coords, pc_coords, vox_size[0])) + logger.info(f"Saving segmentation qc image to {qc_image_path}") + IO_processes.append(create_visualization(subdivision_method, result, midslices, + qc_image_path, ac_coords, pc_coords, vox_size[0])) else: num_slices = segmentation.shape[0] cc_mesh = CC_Mesh(num_slices=num_slices) @@ -418,22 +418,22 @@ def process_slices( if (one_debug_image and slice_idx == num_slices // 2) or not one_debug_image: if not one_debug_image: - debug_path_base, debug_path_ext = str(debug_image_path).rsplit('.', 1) - debug_path_with_postfix = f"{debug_path_base}_slice_{slice_idx}" - - debug_output_path_slice = Path(f"{debug_path_with_postfix}.{debug_path_ext}") - debug_output_path_slice = debug_output_path_slice.with_suffix('.png') + qc_path_base, qc_path_ext = str(qc_image_path).rsplit('.', 1) + qc_path_with_postfix = f"{qc_path_base}_slice_{slice_idx}" + + qc_output_path_slice = Path(f"{qc_path_with_postfix}.{qc_path_ext}") + qc_output_path_slice = qc_output_path_slice.with_suffix('.png') else: - debug_output_path_slice = debug_image_path + qc_output_path_slice = qc_image_path if verbose: - logger.info(f"Saving segmentation qc image to {debug_output_path_slice}") + logger.info(f"Saving segmentation qc image to {qc_output_path_slice}") current_slice_in_volume = midslices.shape[0] // 2 - num_slices // 2 + slice_idx # Create visualization for this slice - IO_processes.append(create_visualization(subdivision_method, result, - midslices[current_slice_in_volume:current_slice_in_volume+1], - debug_output_path_slice, ac_coords, pc_coords, + IO_processes.append(create_visualization(subdivision_method, result, + midslices[current_slice_in_volume:current_slice_in_volume+1], + qc_output_path_slice, ac_coords, pc_coords, vox_size[0], f' (Slice {slice_idx})')) if save_template is not None: From 6024d8a9a2c117b1cbf6d2c1336998deaebb9c79 Mon Sep 17 00:00:00 2001 From: David Kuegler Date: Wed, 12 Nov 2025 01:17:17 +0100 Subject: [PATCH 26/68] Various documentation and formatting changes as well as optimizations during review + merge --- CorpusCallosum/README.md | 6 +- CorpusCallosum/__init__.py | 3 - CorpusCallosum/cc_visualization.py | 52 +++-- CorpusCallosum/data/constants.py | 2 +- CorpusCallosum/data/fsaverage_cc_template.py | 13 +- .../data/generate_fsaverage_centroids.py | 2 +- CorpusCallosum/data/read_write.py | 21 +- CorpusCallosum/fastsurfer_cc.py | 203 
++++++++--------
 .../localization/localization_inference.py    |  54 ++---
 .../registration/mapping_helpers.py           | 184 ++++++++--------
 .../segmentation/segmentation_inference.py    |  66 ++----
 .../segmentation_postprocessing.py            |  21 +-
 CorpusCallosum/shape/cc_mesh.py               |   9 +-
 CorpusCallosum/shape/cc_postprocessing.py     |   4 +-
 CorpusCallosum/shape/cc_thickness.py          |   2 +-
 .../transforms/localization_transforms.py     |  23 +-
 .../transforms/segmentation_transforms.py     |  12 +-
 CorpusCallosum/visualization/visualization.py |  43 ++--
 FastSurferCNN/download_checkpoints.py         |   5 +-
 doc/conf.py                                   |   3 +-
 env/fastsurfer.yml                            |   1 +
 21 files changed, 344 insertions(+), 385 deletions(-)

diff --git a/CorpusCallosum/README.md b/CorpusCallosum/README.md
index 1c2c3162..35605674 100644
--- a/CorpusCallosum/README.md
+++ b/CorpusCallosum/README.md
@@ -1,7 +1,7 @@
 # Corpus Callosum Pipeline
 
 A deep learning-based pipeline for automated segmentation, analysis, and shape analysis of the corpus callosum in brain MRI scans.
-Also segments the fornix, localizes the AC and PC and standardizes the orientation of the brain.
+Also segments the fornix, localizes the anterior and posterior commissure (AC and PC) and standardizes the orientation of the brain.
 
 ## Overview
 
@@ -15,7 +15,9 @@ This pipeline combines localization and segmentation deep learning models to:
 
 ## Quickstart
 
-``` python3 fastsurfer_cc.py --subject_dir /path/to/fastsurfer/output --verbose ```
+```bash
+python3 fastsurfer_cc.py --subject_dir /path/to/fastsurfer/output --verbose
+```
 
 Gives all standard outputs. Then corpus callosum morphometry can be found at `stats/callosum.CC.midslice.json`, including 100 thickness measurements and areas of sub-segments.
 Visualization will be placed in `/path/to/fastsurfer/output/qc_snapshots`.
 For more detailed info see the following sections.
 
diff --git a/CorpusCallosum/__init__.py b/CorpusCallosum/__init__.py
index 100ab63a..63db725a 100644
--- a/CorpusCallosum/__init__.py
+++ b/CorpusCallosum/__init__.py
@@ -13,11 +13,8 @@
 # limitations under the License.
 
 __all__ = [
-    "config",
     "data",
-    "localization",
     "segmentation",
     "transforms",
     "utils",
-    "visualization",
 ]
diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py
index 8461616d..3481c3cb 100644
--- a/CorpusCallosum/cc_visualization.py
+++ b/CorpusCallosum/cc_visualization.py
@@ -1,5 +1,7 @@
 import argparse
+import sys
 from pathlib import Path
+from typing import Literal
 
 import numpy as np
 
@@ -9,8 +11,8 @@
 from CorpusCallosum.shape.cc_mesh import CC_Mesh
 
 
-def options_parse() -> argparse.Namespace:
-    """Parse command line arguments for the visualization pipeline."""
+def make_parser() -> argparse.ArgumentParser:
+    """Create a command line parser for the visualization pipeline."""
     parser = argparse.ArgumentParser(description="Visualize corpus callosum from template files.")
     parser.add_argument("--contours", type=str, required=False, help="Path to contours.txt file", default=None)
     parser.add_argument("--thickness", type=str, required=True, help="Path to thickness_values.txt file")
@@ -38,12 +40,17 @@
         nargs=2,
         default=None,
         metavar=("MIN", "MAX"),
-        help="Optional fixed range for the colorbar (min max)",
+        required=False,
+        help="Specify the range for the colorbar (2 values: min max). 
Defaults to automatic choice.",
     )
     parser.add_argument("--legend", type=str, default="Thickness (mm)", help="Legend for the colorbar")
     parser.add_argument("--twoD", action="store_true", help="Generate 2D visualization instead of 3D mesh")
 
-    args = parser.parse_args()
+    return parser
+
+
+def options_parse() -> argparse.Namespace:
+    """Parse command line arguments for the pipeline."""
+    args = make_parser().parse_args()
 
     # Create output directory if it doesn't exist
     Path(args.output_dir).mkdir(parents=True, exist_ok=True)
@@ -62,7 +69,7 @@ def main(
     color_range: tuple[float, float] | None = None,
     legend: str | None = None,
     twoD: bool = False,
-) -> None:
+) -> Literal[0] | str:
     """Main function to visualize corpus callosum from template files.
 
     This function loads contours and thickness values from template files,
@@ -146,21 +153,24 @@
     cc_mesh.write_vtk(str(output_dir / "cc_mesh.vtk"))
     cc_mesh.write_fssurf(str(output_dir / "cc_mesh.fssurf"))
     cc_mesh.write_overlay(str(output_dir / "cc_mesh_overlay.curv"))
-    cc_mesh.snap_cc_picture(str(output_dir / "cc_mesh_snap.png"))
-
+    try:
+        cc_mesh.snap_cc_picture(str(output_dir / "cc_mesh_snap.png"))
+    except RuntimeError:
+        return ("The cc_visualization script requires whippersnappy>=1.3.1 to make screenshots, install with "
+                "`pip install whippersnappy>=1.3.1` !")
+    return 0
 
 
 if __name__ == "__main__":
-    options = options_parse()
-    main_args = {
-        "contours_path": options.contours,
-        "thickness_path": options.thickness,
-        "measurement_points_path": options.measurement_points,
-        "output_dir": options.output_dir,
-        "resolution": options.resolution,
-        "smoothing_window": options.smoothing_window,
-        "colormap": options.colormap,
-        "color_range": options.color_range,
-        "legend": options.legend,
-        "twoD": options.twoD,
-    }
-    main(**main_args)
+    options = make_parser().parse_args()
+    sys.exit(main(
+        contours_path=options.contours,
+        thickness_path=options.thickness,
+        measurement_points_path=options.measurement_points,
+        output_dir=options.output_dir,
+        resolution=options.resolution,
+        smooth_iterations=options.smooth_iterations,
+        colormap=options.colormap,
+        color_range=options.color_range,
+        legend=options.legend,
+        twoD=options.twoD,
+    ))
diff --git a/CorpusCallosum/data/constants.py b/CorpusCallosum/data/constants.py
index dd20a8c5..6cc2101a 100644
--- a/CorpusCallosum/data/constants.py
+++ b/CorpusCallosum/data/constants.py
@@ -22,7 +22,7 @@
 FSAVERAGE_MIDDLE = 128  # Middle slice index in fsaverage space
 CC_LABEL = 192  # Label value for corpus callosum in segmentation
 FORNIX_LABEL = 250  # Label value for fornix in segmentation
-SUBSEGEMNT_LABELS = [251, 252, 253, 254, 255]  # labels for subsegments in segmentation
+SUBSEGMENT_LABELS = [251, 252, 253, 254, 255]  # labels for subsegments in segmentation
 FASTSURFER_ROOT = Path(__file__).parent.parent.parent  # TODO: use FastSurfer function for this
 
 
diff --git a/CorpusCallosum/data/fsaverage_cc_template.py b/CorpusCallosum/data/fsaverage_cc_template.py
index 3e0cd854..497c52ed 100644
--- a/CorpusCallosum/data/fsaverage_cc_template.py
+++ b/CorpusCallosum/data/fsaverage_cc_template.py
@@ -19,7 +19,9 @@
 import numpy as np
 from scipy import ndimage
 
+from CorpusCallosum.data import constants
 from CorpusCallosum.shape.cc_postprocessing import process_slice
+from FastSurferCNN.utils.brainvolstats import mask_in_array
 
 
 def smooth_contour(contour: tuple[np.ndarray, np.ndarray], window_size: int = 5) -> tuple[np.ndarray, np.ndarray]:
@@ -93,7 +95,7 @@ def load_fsaverage_cc_template() -> tuple[
     fsaverage_seg_path 
= freesurfer_home / 'subjects' / 'fsaverage' / 'mri' / 'aparc+aseg.mgz' fsaverage_seg = nib.load(fsaverage_seg_path) - segmentation = fsaverage_seg.get_fdata() + segmentation = np.asarray(fsaverage_seg.dataobj) PC = np.array([131, 99]) AC = np.array([135, 130]) @@ -101,11 +103,7 @@ def load_fsaverage_cc_template() -> tuple[ midslice = segmentation.shape[0]//2 +1 - cc_mask = segmentation[midslice] == 251 - cc_mask |= segmentation[midslice] == 252 - cc_mask |= segmentation[midslice] == 253 - cc_mask |= segmentation[midslice] == 254 - cc_mask |= segmentation[midslice] == 255 + cc_mask = mask_in_array(segmentation[midslice], constants.SUBSEGMENT_LABELS) # Smooth the CC mask to reduce noise and irregularities @@ -120,8 +118,7 @@ def load_fsaverage_cc_template() -> tuple[ cc_mask_smoothed = cc_mask_smoothed > 0.5 # Use the smoothed mask for further processing - cc_mask = cc_mask_smoothed.astype(int) - cc_mask[cc_mask > 0] = 192 + cc_mask = cc_mask_smoothed.astype(int) * 192 (_, contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx) = process_slice(segmentation=cc_mask[None], diff --git a/CorpusCallosum/data/generate_fsaverage_centroids.py b/CorpusCallosum/data/generate_fsaverage_centroids.py index 9b69beb5..c43072f7 100644 --- a/CorpusCallosum/data/generate_fsaverage_centroids.py +++ b/CorpusCallosum/data/generate_fsaverage_centroids.py @@ -163,4 +163,4 @@ def main() -> None: if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/CorpusCallosum/data/read_write.py b/CorpusCallosum/data/read_write.py index 52df104d..00442944 100644 --- a/CorpusCallosum/data/read_write.py +++ b/CorpusCallosum/data/read_write.py @@ -15,6 +15,7 @@ import json import multiprocessing from pathlib import Path +from typing import overload import nibabel as nib import numpy as np @@ -52,8 +53,15 @@ def run_in_background(function: callable, debug: bool = False, *args, **kwargs) return process +@overload +def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: None = None) -> dict[int, np.ndarray]: + ... -def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: list[int] | None = None) -> dict[int, np.ndarray]: +@overload +def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: list[int]) -> tuple[dict[int, np.ndarray], list[int]]: + ... + +def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: list[int] | None = None): """Get centroids of segmentation labels in RAS coordinates. 
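The `@overload` pair above only narrows the advertised return type: a bare call yields the centroid dict, while an explicit `label_ids` list additionally reports which labels were found. A minimal, self-contained sketch of the same pattern (hypothetical `centroids` helper, not this module's code):

```python
from typing import overload

import numpy as np


@overload
def centroids(seg: np.ndarray, label_ids: None = None) -> dict[int, np.ndarray]: ...
@overload
def centroids(seg: np.ndarray, label_ids: list[int]) -> tuple[dict[int, np.ndarray], list[int]]: ...
def centroids(seg, label_ids=None):
    """Centroid per label; with label_ids, also return which labels were found."""
    wanted = np.unique(seg)[1:] if label_ids is None else label_ids  # [1:] skips background 0
    found = {int(i): np.argwhere(seg == i).mean(axis=0) for i in wanted if np.any(seg == i)}
    return found if label_ids is None else (found, sorted(found))
```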
    Parameters
@@ -187,12 +195,7 @@ def load_fsaverage_centroids(centroids_path: str | Path) -> dict[int, np.ndarray
        centroids_data = json.load(f)

     # Convert string keys back to integers and lists back to numpy arrays
-    centroids = {}
-    for label_str, centroid_list in centroids_data.items():
-        label_id = int(label_str)
-        centroids[label_id] = np.array(centroid_list)
-
-    return centroids
+    return {int(label): np.array(centroid) for label, centroid in centroids_data.items()}


 def load_fsaverage_affine(affine_path: str | Path) -> np.ndarray:
@@ -270,8 +273,8 @@ def load_fsaverage_data(data_path: str | Path) -> tuple[np.ndarray, dict, np.nda
     if "header" not in data:
         raise ValueError("Required field 'header' missing from data file")

-    header_fields = ["dims", "delta", "Mdc", "Pxyz_c"]
-    for field in header_fields:
+    required_header_fields = ["dims", "delta", "Mdc", "Pxyz_c"]
+    for field in required_header_fields:
         if field not in data["header"]:
             raise ValueError(f"Required header field missing: {field}")

diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py
index ab705a52..d0208997 100644
--- a/CorpusCallosum/fastsurfer_cc.py
+++ b/CorpusCallosum/fastsurfer_cc.py
@@ -1,11 +1,28 @@
+#!/usr/bin/env python3
+# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import argparse
 import json
 from pathlib import Path
+from typing import Literal

 # import warnings warnings.filterwarnings("ignore", message="TypedStorage is deprecated")

 import nibabel as nib
 import numpy as np
 import torch
+from monai.networks.nets import DenseNet
+from numpy import typing as npt

 import FastSurferCNN.utils.logging as logging
 from CorpusCallosum.data.constants import (
@@ -39,14 +56,17 @@
     process_slices,
 )
 from FastSurferCNN.data_loader.conform import is_conform
+from FastSurferCNN.utils.common import find_device
 from recon_surf import lta
 from recon_surf.align_points import find_rigid

 logger = logging.get_logger(__name__)

+SliceSelection = Literal["middle", "all"] | int
+SubdivisionMethod = Literal["shape", "vertical", "angular", "eigenvector"]
+

-
-def options_parse() -> argparse.Namespace:
-    """Parse command line arguments for the pipeline."""
+def make_parser() -> argparse.ArgumentParser:
+    """Create the argument parser object for the pipeline."""
     parser = argparse.ArgumentParser()

     # Specify subject directory + subject ID, OR specify individual MRI and segmentation files + output paths
@@ -81,10 +101,10 @@ def options_parse() -> argparse.Namespace:
         help="Select device to run inference on: cpu, or cuda (= Nvidia gpu) or specify a certain gpu (e.g. cuda:1), \
             Default: auto",
     )
-    
+
     parser.add_argument(
-        "--num_thickness_points", 
-        type=int, default=100, 
+        "--num_thickness_points",
+        type=int, default=100,
         help="Number of points for thickness estimation."
) parser.add_argument( @@ -96,7 +116,6 @@ def options_parse() -> argparse.Namespace: ) parser.add_argument( "--subdivision_method", - type=str, default="shape", help="Method for contour subdivision. \ Options: shape (Intercallosal subdivision perpendicular to intercallosal line), vertical \ @@ -112,6 +131,10 @@ def options_parse() -> argparse.Namespace: help="Gaussian sigma for smoothing during contour detection. Higher values mean a smoother" " CC outline, at the cost of precision. (default: 5)", ) + def _slice_selection(a: str) -> SliceSelection: + if a.lower() in ("middle", "all"): + return a.lower() + return int(a) parser.add_argument( "--slice_selection", type=str, @@ -128,7 +151,7 @@ def options_parse() -> argparse.Namespace: ######## OUTPUT PATHS ######### # 4. Options for advanced, technical parameters - advanced = parser.add_argument_group(title="Advanced options", + advanced = parser.add_argument_group(title="Advanced options", description="Custom output paths, useful if no standard case directory is used.") advanced.add_argument("--qc_output_dir", type=Path, @@ -190,7 +213,7 @@ def options_parse() -> argparse.Namespace: default=None, ) advanced.add_argument( - "--save_template", + "--save_template_dir", type=Path, help="Directory path where to save contours.txt and thickness_values.txt files. \ These files can be used to visualize the CC shape and volume in 3D.", @@ -217,7 +240,7 @@ def options_parse() -> argparse.Namespace: advanced.add_argument( "--cc_html_path", type=Path, - help=f"Path to CC 3D visualization for CC HTML file (default: subject_dir/{STANDARD_OUTPUT_PATHS['cc_html']})", + help=f"Path to CC 3D visualization for CC HTML file (default: subject_dir/{STANDARD_OUTPUT_PATHS['cc_html']})", default=None, ) advanced.add_argument( @@ -248,9 +271,14 @@ def options_parse() -> argparse.Namespace: default=None, ) ############ END OF OUTPUT PATHS ############ - + return parser + + +def options_parse() -> argparse.Namespace: + """Parse command line arguments for the pipeline.""" + parser = make_parser() args = parser.parse_args() # Reconstruct subject_dir from sd and sid (but sd might be stored as out_dir by parser_defaults) @@ -298,7 +326,7 @@ def options_parse() -> argparse.Namespace: # Set default output paths if not provided for key, value in STANDARD_OUTPUT_PATHS.items(): if not getattr(args, f"{key}_path") and value is not None: - setattr(args, f"{key}_path", str(subject_dir_path / value)) + setattr(args, f"{key}_path", subject_dir_path / value) # Set output_dir to subject_dir args.output_dir = str(subject_dir_path) @@ -312,8 +340,8 @@ def options_parse() -> argparse.Namespace: return args -def centroid_registration(aseg_nib: nib.Nifti1Image, verbose: bool = False) -> tuple[ - np.ndarray, np.ndarray, np.ndarray, nib.Nifti1Header, np.ndarray +def centroid_registration(aseg_nib: nib.Nifti1Image) -> tuple[ + npt.NDArray[float], npt.NDArray[float], npt.NDArray[float], nib.Nifti1Header, npt.NDArray[float] ]: """Perform centroid-based registration between subject and fsaverage space. @@ -324,8 +352,6 @@ def centroid_registration(aseg_nib: nib.Nifti1Image, verbose: bool = False) -> t ---------- aseg_nib : nibabel.Nifti1Image Subject's segmentation image. - verbose : bool, optional - Whether to print progress information, by default False. Returns ------- @@ -346,8 +372,7 @@ def centroid_registration(aseg_nib: nib.Nifti1Image, verbose: bool = False) -> t to perform the registration. 
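The registration in the function body below boils down to composing voxel-to-RAS, RAS-to-RAS, and RAS-to-voxel maps. A sketch of that arithmetic under assumed 4x4 inputs (illustrative names, not this module's API):

```python
import numpy as np

def compose_vox2vox(mov_affine, dst_affine, ras2ras, voxel_sizes):
    """mov voxel -> mov RAS -> dst RAS -> dst voxel, with dst rescaled to mov resolution."""
    scale = np.diag(list(voxel_sizes) + [1.0])  # hires destination grid
    return np.linalg.inv(scale @ dst_affine) @ ras2ras @ mov_affine

# identity inputs must compose to the identity map
ident = compose_vox2vox(np.eye(4), np.eye(4), np.eye(4), (1.0, 1.0, 1.0))
assert np.allclose(ident, np.eye(4))
```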
It matches corresponding anatomical structures between the subject's segmentation and fsaverage space. """ - if verbose: - print("Centroid registration") + logger.info("Starting centroid registration") # Load pre-computed fsaverage centroids and data from static files centroids_dst = load_fsaverage_centroids(FSAVERAGE_CENTROIDS_PATH) @@ -365,25 +390,21 @@ def centroid_registration(aseg_nib: nib.Nifti1Image, verbose: bool = False) -> t orig_fsaverage_ras2ras = find_rigid(p_mov=centroids_mov.T, p_dst=centroids_dst.T) # make affine that increases resolution to orig resolution - resolution_orig = aseg_nib.header.get_zooms()[0] - resolution_trans = np.eye(4) - resolution_trans[0, 0] = resolution_orig - resolution_trans[1, 1] = resolution_orig - resolution_trans[2, 2] = resolution_orig + resolution_trans = np.diagflat(list(aseg_nib.header.get_zooms()[:3]) + [1]) orig_fsaverage_vox2vox = ( np.linalg.inv(resolution_trans @ fsaverage_affine) @ orig_fsaverage_ras2ras @ aseg_nib.affine ) fsaverage_hires_affine = resolution_trans @ fsaverage_affine - + logger.info("Centroid registration successful!") return orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, fsaverage_hires_affine, fsaverage_header, vox2ras_tkr def localize_ac_pc( midslices: np.ndarray, aseg_nib: "nib.Nifti1Image", - orig_fsaverage_vox2vox: np.ndarray, - model_localization: "torch.nn.Module", + orig_fsaverage_vox2vox: npt.NDArray[float], + model_localization: DenseNet, slices_to_analyze: int ) -> tuple[np.ndarray, np.ndarray]: """Localize anterior and posterior commissure points in the brain. @@ -399,7 +420,7 @@ def localize_ac_pc( Subject's segmentation image. orig_fsaverage_vox2vox : np.ndarray Transformation matrix to fsaverage space. - model_localization : torch.nn.Module + model_localization : DenseNet Trained model for AC-PC detection. slices_to_analyze : int Number of slices to process. @@ -413,7 +434,7 @@ def localize_ac_pc( """ # get center of third ventricle from aseg and map to fsaverage space - third_ventricle_mask = aseg_nib.get_fdata() == 4 + third_ventricle_mask = np.asarray(aseg_nib.dataobj) == 4 third_ventricle_center = np.argwhere(third_ventricle_mask).mean(axis=0) third_ventricle_center_vox = apply_transform_to_pt(third_ventricle_center, orig_fsaverage_vox2vox, inv=False) @@ -431,12 +452,12 @@ def localize_ac_pc( def segment_cc( midslices: np.ndarray, - ac_coords: np.ndarray, - pc_coords: np.ndarray, + ac_coords: npt.NDArray[float], + pc_coords: npt.NDArray[float], aseg_nib: "nib.Nifti1Image", model_segmentation: "torch.nn.Module", - slices_to_analyze: int -) -> tuple[np.ndarray, np.ndarray]: + slices_to_analyze: int, +) -> tuple[npt.NDArray[bool], npt.NDArray[float]]: """Segment the corpus callosum using a trained model. Performs corpus callosum segmentation on mid-sagittal slices using a trained model, @@ -460,38 +481,33 @@ def segment_cc( Returns ------- - tuple[np.ndarray, np.ndarray] - - segmentation : Binary segmentation of the corpus callosum. - - outputs_soft : Soft segmentation probabilities. + segmentation : np.ndarray + Binary segmentation of the corpus callosum. + outputs_soft : np.ndarray + Soft segmentation probabilities. 
""" # get 5 mm of slices output with 9 slices per inference - midslices_middle = midslices.shape[0] // 2 - middle_slices_segmentation = midslices[ - midslices_middle - slices_to_analyze // 2 - 4 : midslices_middle + slices_to_analyze // 2 + 5 - ] - segmentation, inputs, outputs_avg, outputs_soft = segmentation_inference.run_inference_on_slice( + midslices_start = midslices.shape[0] // 2 - slices_to_analyze // 2 - 4 + middle_slices_slab = midslices[midslices_start:midslices_start + slices_to_analyze + 9] + pre_clean_segmentation, inputs, outputs_avg, outputs_soft = segmentation_inference.run_inference_on_slice( model_segmentation, - middle_slices_segmentation, + middle_slices_slab, AC_center=ac_coords, PC_center=pc_coords, voxel_size=aseg_nib.header.get_zooms()[0], ) - pre_clean_segmentation = segmentation.copy() - segmentation, cc_volume_mask = segmentation_postprocessing.clean_cc_segmentation(segmentation) + segmentation, cc_volume_mask = segmentation_postprocessing.clean_cc_segmentation(pre_clean_segmentation) # print a warning if the cc_volume_mask touches the edge of the segmentation if ( - np.any(cc_volume_mask[:, 0, :]) - or np.any(cc_volume_mask[:, -1, :]) - or np.any(cc_volume_mask[:, :, 0]) - or np.any(cc_volume_mask[:, :, -1]) + np.any(cc_volume_mask[:, [0, -1]]) + or np.any(cc_volume_mask[:, :, [0, -1]]) ): - print("Warning: CC voume mask touches the edge of the segmentation field-of-view, CC might be truncated") + logger.warning("CC volume mask touches the edge of the segmentation field-of-view, CC might be truncated") # get voxels that were removed during cleaning - removed_voxels = pre_clean_segmentation != segmentation - outputs_soft[removed_voxels, 1] = 0 + outputs_soft[pre_clean_segmentation != segmentation, 1] = 0 return segmentation, outputs_soft @@ -499,7 +515,7 @@ def segment_cc( def main( - in_mri_path: str | Path, + t1_path: str | Path, aseg_path: str | Path, output_dir: str | Path, slice_selection: str = "middle", @@ -507,9 +523,9 @@ def main( verbose: bool = False, num_thickness_points: int = 100, subdivisions: list[float] | None = None, - subdivision_method: str = "shape", + subdivision_method: SubdivisionMethod = "shape", contour_smoothing: float = 5, - save_template: str | Path | None = None, + save_template_dir: str | Path | None = None, device: str = "auto", upright_volume_path: str | Path = None, segmentation_path: str | Path = None, @@ -535,9 +551,9 @@ def main( Parameters ---------- - in_mri_path : str or Path + conf_name : str or Path Path to input MRI file. - aseg_path : str or Path + aseg_name : str or Path Path to input segmentation file. output_dir : str or Path Directory for output files. @@ -550,47 +566,48 @@ def main( num_thickness_points : int, optional Number of points for thickness estimation, by default 100. subdivisions : list[float], optional - List of subdivision fractions for CC subsegmentation, by default None. - subdivision_method : str, optional - Method for contour subdivision ('shape', 'vertical', 'angular', 'eigenvector'), by default 'shape'. - contour_smoothing : float, optional - Gaussian sigma for smoothing during contour detection, by default 5. - save_template : str or Path, optional - Directory path where to save contours.txt and thickness_values.txt files, by default None. + List of subdivision fractions for CC subsegmentation. + subdivision_method : any of "shape", "vertical", "angular", "eigenvector", default="shape" + Method for contour subdivision. 
+ contour_smoothing : float, default=5 + Gaussian sigma for smoothing during contour detection. + save_template_dir : str or Path, optional + Directory path where to save contours.txt and thickness_values.txt files. \ + These files can be used to visualize the CC shape and volume in 3D. device : str, optional Device to run inference on ('auto', 'cpu', 'cuda', or 'cuda:X'), by default 'auto'. upright_volume_path : str or Path, optional - Path to save upright volume, by default None. + Path to save upright volume. segmentation_path : str or Path, optional - Path to save segmentation, by default None. + Path to save segmentation. postproc_results_path : str or Path, optional - Path to save post-processing results, by default None. + Path to save post-processing results. cc_markers_path : str or Path, optional - Path to save CC markers, by default None. + Path to save CC markers. upright_lta_path : str or Path, optional - Path to save upright LTA transform, by default None. + Path to save upright LTA transform. orient_volume_lta_path : str or Path, optional - Path to save orientation transform, by default None. + Path to save orientation transform. surf_file_path : str or Path, optional - Path to save surface file, by default None. + Path to save surface file. overlay_file_path : str or Path, optional - Path to save overlay file, by default None. + Path to save overlay file. cc_html_path : str or Path, optional - Path to save HTML visualization, by default None. + Path to save HTML visualization. vtk_file_path : str or Path, optional - Path to save VTK file, by default None. + Path to save VTK file. orig_space_segmentation_path : str or Path, optional Path to save segmentation in original space, by default None. qc_image_path : str or Path, optional Path to save QC images, by default None. thickness_image_path : str or Path, optional - Path to save thickness visualization, by default None. + Path to save thickness visualization. softlabels_cc_path : str or Path, optional - Path to save CC soft labels, by default None. + Path to save CC soft labels. softlabels_fn_path : str or Path, optional - Path to save fornix soft labels, by default None. + Path to save fornix soft labels. softlabels_background_path : str or Path, optional - Path to save background soft labels, by default None. + Path to save background soft labels. 
Notes ----- @@ -619,27 +636,28 @@ def main( logging.setup_logging(None) # Log to stdout only logger.info("Starting corpus callosum analysis pipeline") - logger.info(f"Input MRI: {in_mri_path}") + logger.info(f"Input MRI: {t1_path}") logger.info(f"Input segmentation: {aseg_path}") logger.info(f"Output directory: {output_dir}") # Convert all paths to Path objects - in_mri_path = Path(in_mri_path) + t1_path = Path(t1_path) aseg_path = Path(aseg_path) - output_dir = Path(output_dir) - qc_output_dir = Path(qc_output_dir) if qc_output_dir else None - save_template = Path(save_template) if save_template else None + if output_dir is not None: + output_dir = Path(output_dir) + if save_template_dir: + save_template_dir = Path(save_template_dir) # Validate subdivision fractions - for i in subdivisions: - if i < 0 or i > 1: - logger.error(f"Error: Subdivision fractions must be between 0 and 1, but got: {i}") - raise ValueError(f"Subdivision fractions must be between 0 and 1, but got: {i}") + if any(i < 0 or i > 1 for i in subdivisions): + logger.error(f"Error: Subdivision fractions must be between 0 and 1, but got: {subdivisions}") + import sys + sys.exit(1) #### setup variables IO_processes = [] - orig = nib.load(in_mri_path) + orig = nib.load(t1_path) # 5 mm around the midplane slices_to_analyze = int(np.ceil(5 / orig.header.get_zooms()[0])) @@ -672,7 +690,7 @@ def main( logger.info("Performing centroid registration to fsaverage space") (orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, fsaverage_hires_affine, fsaverage_header, fsaverage_vox2ras_tkr) = centroid_registration( - aseg_nib, verbose=False + aseg_nib ) if verbose: @@ -728,14 +746,12 @@ def main( # Create a temporary segmentation image with proper affine for enhanced postprocessing - temp_seg_affine = fsaverage_hires_affine @ np.linalg.inv(np.eye(4)) - # Process slices based on selection mode logger.info(f"Processing slices with selection mode: {slice_selection}") slice_results, slice_io_processes = process_slices( segmentation=segmentation, slice_selection=slice_selection, - temp_seg_affine=temp_seg_affine, + temp_seg_affine=fsaverage_hires_affine, midslices=midslices, ac_coords=ac_coords, pc_coords=pc_coords, @@ -753,11 +769,10 @@ def main( vox_size=orig.header.get_zooms(), vox2ras_tkr=fsaverage_vox2ras_tkr, verbose=verbose, - save_template=save_template, + save_template=save_template_dir, ) IO_processes.extend(slice_io_processes) - outer_contours = [slice_result['split_contours'][0] for slice_result in slice_results] if len(outer_contours) > 1 and not check_area_changes(outer_contours, verbose=True): @@ -894,7 +909,7 @@ def main( ) logger.info(f"Saving LTA to standardized space: {orient_volume_lta_path}") lta.writeLTA( - orient_volume_lta_path, orig_to_standardized_ras2ras, in_mri_path, orig.header, in_mri_path, orig.header + orient_volume_lta_path, orig_to_standardized_ras2ras, t1_path, orig.header, t1_path, orig.header ) for process in IO_processes: @@ -914,7 +929,7 @@ def main( main_args.pop("sid", None) # Rename keys to match main function parameters - main_args["in_mri_path"] = main_args.pop("t1") + main_args["t1_path"] = main_args.pop("t1") main_args["aseg_path"] = main_args.pop("aseg_name") main_args["output_dir"] = main_args.pop("subject_dir", ".") diff --git a/CorpusCallosum/localization/localization_inference.py b/CorpusCallosum/localization/localization_inference.py index df108ace..105a57c0 100644 --- a/CorpusCallosum/localization/localization_inference.py +++ b/CorpusCallosum/localization/localization_inference.py @@ -12,6 
+12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections.abc import Callable
 from pathlib import Path

 import numpy as np
@@ -43,8 +44,9 @@ def load_model(device: torch.device | None = None) -> DenseNet:
     DenseNet
         Loaded and initialized model in evaluation mode
     """
-    if device is None:
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if device is None or device == "auto":
+        from FastSurferCNN.utils.common import find_device
+        device = find_device(device or "auto")

     # Initialize model architecture (must match training)
     model = DenseNet(  # densenet201
@@ -96,7 +98,7 @@ def get_transforms() -> transforms.Compose:
         CropAroundACPCFixedSize(
             keys=['image'],
             fixed_size=(64, 64),
-            random_translate=0
+            random_translate=0,
         ),
     ]
     return transforms.Compose(tr)
@@ -139,7 +141,7 @@ def preprocess_volume(
     return transformed


-def run_inference(model: torch.nn.Module,
+def run_inference(model: DenseNet,
                   image_volume: np.ndarray,
                   third_ventricle_center: np.ndarray,
                   device: torch.device | None = None,
@@ -163,12 +165,14 @@ def run_inference(model: torch.nn.Module,

     Returns
     -------
-    tuple
-        Tuple containing:
-        - np.ndarray: Predicted PC coordinates
-        - np.ndarray: Predicted AC coordinates
-        - np.ndarray: Processed input images
-        - tuple: Crop offsets (left, top)
+    pc_coord : np.ndarray
+        Predicted PC coordinates.
+    ac_coord : np.ndarray
+        Predicted AC coordinates.
+    image : np.ndarray
+        Processed input images.
+    crop_offsets : tuple[int, int]
+        Crop offsets (left, top).
     """
     if device is None:
         device = next(model.parameters()).device
@@ -188,13 +192,7 @@ def run_inference(model: torch.nn.Module,
     inputs = inputs.transpose(0, 1)

     batch_size, channels, height, width = inputs.shape
-    views = []
-    for i in range(batch_size - 2):  # -2 to ensure we have 3 slices per view
-        view = inputs[i:i+3]  # Take 3 consecutive slices
-        view = view.reshape(1, 3*channels, height, width)  # Reshape to combine slices into channels
-        views.append(view)
-
-    inputs = torch.cat(views, dim=0)  # Stack all views into batch dimension
+    # move the 3-slice window next to the batch axis before flattening it into the channels
+    inputs = inputs.unfold(0, 3, 1).movedim(-1, 1).reshape(-1, 3*channels, height, width)

     # Run inference
@@ -207,20 +205,13 @@
     #                        dtype=torch.float32,
     #                        device=device)
     outputs = outputs * 64
-
-    outputs[:, 0] += t_dict['crop_left']
-    outputs[:, 1] += t_dict['crop_top']
-    outputs[:, 2] += t_dict['crop_left']
-    outputs[:, 3] += t_dict['crop_top']
-
-    return (outputs[:,:2].cpu().numpy(),
-            outputs[:,2:].cpu().numpy(),
-            inputs.cpu().numpy(),
-            (t_dict['crop_left'], t_dict['crop_top']))
+    t_crops = [[t_dict['crop_left'], t_dict['crop_top'], t_dict['crop_left'], t_dict['crop_top']]]
+    outs: np.ndarray = (outputs + torch.tensor(t_crops, dtype=outputs.dtype, device=outputs.device)).cpu().numpy()
+    return outs[:, :2], outs[:, 2:], inputs.cpu().numpy(), tuple(int(t_dict[k]) for k in ['crop_left', 'crop_top'])


-def run_inference_on_slice(model: torch.nn.Module,
+def run_inference_on_slice(model: DenseNet,
                            image_slice: np.ndarray,
                            center_pt: np.ndarray,
                            debug_output: str | None = None) -> tuple[np.ndarray, np.ndarray]:
@@ -239,9 +230,10 @@
     Returns
     -------
-    tuple[np.ndarray, np.ndarray]
-        Detected AC and PC coordinates as (ac_coords, pc_coords)
-        Each coordinate array has shape (2,) containing [y,x] positions
+    ac_coords : np.ndarray
+        Detected AC coordinates with shape (2,) containing its [y,x] positions.
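A toy sanity check for the `unfold`-based sliding window used in the hunk above: `unfold` appends the window dimension last, so it has to be moved next to the batch axis before it can be flattened into the channels (shapes here are made up):

```python
import torch

x = torch.arange(5 * 2 * 4 * 4, dtype=torch.float32).reshape(5, 2, 4, 4)  # (B, C, H, W)
views = x.unfold(0, 3, 1)               # (B-2, C, H, W, 3), window dim appended last
views = views.movedim(-1, 1)            # (B-2, 3, C, H, W), window next to batch
views = views.reshape(-1, 3 * 2, 4, 4)  # (B-2, 3*C, H, W)

# equivalent to the explicit python loop that the patch replaces
ref = torch.stack([x[i:i + 3].reshape(3 * 2, 4, 4) for i in range(3)])
assert torch.equal(views, ref)
```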
+ pc_coords : np.ndarray + Detected PC coordinates with shape (2,) containing its [y,x] positions. """ # Run inference diff --git a/CorpusCallosum/registration/mapping_helpers.py b/CorpusCallosum/registration/mapping_helpers.py index f53381d2..99958ea0 100644 --- a/CorpusCallosum/registration/mapping_helpers.py +++ b/CorpusCallosum/registration/mapping_helpers.py @@ -3,30 +3,34 @@ import nibabel as nib import numpy as np import SimpleITK as sitk +from numpy import typing as npt from scipy.ndimage import affine_transform -import FastSurferCNN.utils.logging as logging +from FastSurferCNN.utils import logging logger = logging.get_logger(__name__) -def make_midplane_affine(orig_affine: np.ndarray, slices_to_analyze: int = 1, - offset: int = 4) -> np.ndarray: +def make_midplane_affine( + orig_affine: npt.NDArray[float], + slices_to_analyze: int = 1, + offset: int = 4, + ) -> npt.NDArray[float]: """Create affine transformation matrix for midplane slices. Parameters ---------- orig_affine : np.ndarray - Original image affine matrix (4x4) - slices_to_analyze : int, optional - Number of slices to analyze around midplane, by default 1 - offset : int, optional - Additional offset in x direction, by default 4 + Original image affine matrix (4x4). + slices_to_analyze : int, default=1 + Number of slices to analyze around midplane. + offset : int, default=4 + Additional offset in x direction. Returns ------- np.ndarray - 4x4 affine matrix for midplane slices + 4x4 affine matrix for midplane slices. """ # Create translation matrix to center on midplane orig_to_seg = np.eye(4) @@ -38,7 +42,7 @@ def make_midplane_affine(orig_affine: np.ndarray, slices_to_analyze: int = 1, return seg_affine -def correct_nodding(ac_pt: np.ndarray, pc_pt: np.ndarray) -> np.ndarray: +def correct_nodding(ac_pt: npt.NDArray[float], pc_pt: npt.NDArray[float]) -> npt.NDArray[float]: """Calculate rotation matrix to correct head nodding. Calculates rotation matrix to align AC-PC line with posterior direction, @@ -47,14 +51,14 @@ def correct_nodding(ac_pt: np.ndarray, pc_pt: np.ndarray) -> np.ndarray: Parameters ---------- ac_pt : np.ndarray - Coordinates of the anterior commissure point + Coordinates of the anterior commissure point. pc_pt : np.ndarray - Coordinates of the posterior commissure point + Coordinates of the posterior commissure point. Returns ------- np.ndarray - 3x3 rotation matrix to align AC-PC line with posterior direction + 3x3 rotation matrix to align AC-PC line with posterior direction. """ ac_pc_vec = pc_pt - ac_pt ac_pc_dist = np.linalg.norm(ac_pc_vec) @@ -90,25 +94,24 @@ def correct_nodding(ac_pt: np.ndarray, pc_pt: np.ndarray) -> np.ndarray: return rotation_matrix -def apply_transform_to_pt(pts: np.ndarray, T: np.ndarray, inv: bool = False) -> np.ndarray: +def apply_transform_to_pt(pts: npt.NDArray[float], T: npt.NDArray[float], inv: bool = False) -> npt.NDArray[float]: """Apply homogeneous transformation matrix to points. Parameters ---------- pts : np.ndarray - Point coordinates to transform, shape (3,) or (3, N) + Point coordinates to transform, shape (3,) or (3, N). T : np.ndarray - 4x4 homogeneous transformation matrix - inv : bool, optional - If True, applies inverse of transformation, by default False + 4x4 homogeneous transformation matrix. + inv : bool, default=False + If True, applies inverse of transformation. Returns ------- np.ndarray - Transformed point coordinates, shape (3,) or (3, N) + Transformed point coordinates, shape (3,) or (3, N). 
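For readers unfamiliar with the homogeneous-coordinate convention documented above, a minimal sketch of applying a 4x4 matrix to points of shape (3,) or (3, N) (illustrative helper, not this module's implementation):

```python
import numpy as np

def apply_h(T: np.ndarray, pts: np.ndarray) -> np.ndarray:
    """Apply a 4x4 homogeneous matrix to points of shape (3,) or (3, N)."""
    single = pts.ndim == 1
    p = pts.reshape(3, -1)
    hom = np.vstack([p, np.ones((1, p.shape[1]))])  # append homogeneous 1s
    out = (T @ hom)[:3]
    return out[:, 0] if single else out

# a pure translation moves every point by the same offset
T = np.eye(4); T[:3, 3] = [1.0, 2.0, 3.0]
assert np.allclose(apply_h(T, np.zeros(3)), [1.0, 2.0, 3.0])
```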
""" if inv: - T = T.copy() T = np.linalg.inv(T) if pts.ndim == 1: @@ -119,34 +122,35 @@ def apply_transform_to_pt(pts: np.ndarray, T: np.ndarray, inv: bool = False) -> def get_mapping_to_standard_space( orig: "nib.Nifti1Image", - ac_coords_3d: np.ndarray, - pc_coords_3d: np.ndarray, - orig_fsaverage_vox2vox: np.ndarray, -) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + ac_coords_3d: npt.NDArray[float], + pc_coords_3d: npt.NDArray[float], + orig_fsaverage_vox2vox: npt.NDArray[float], +) -> tuple[npt.NDArray[float], npt.NDArray[float], npt.NDArray[float], npt.NDArray[float], npt.NDArray[float]]: """Get transformations to map image to standard space. Parameters ---------- orig : nib.Nifti1Image - Original image + Original image. ac_coords_3d : np.ndarray - AC coordinates in 3D space + AC coordinates in 3D space. pc_coords_3d : np.ndarray - PC coordinates in 3D space + PC coordinates in 3D space. orig_fsaverage_vox2vox : np.ndarray - Transformation matrix from original to fsaverage space - output_dir : str or Path - Directory to save transformation files + Transformation matrix from original to fsaverage space. Returns ------- - tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray] - Contains: - - upright_volume : Upright transformed volume - - standardized_volume : Volume in standard space - - ac_coords_standardized : AC coordinates in standard space - - pc_coords_standardized : PC coordinates in standard space - - standardized_affine : Affine matrix for standard space + upright_volume : np.ndarray + Upright transformed volume. + standardized_volume : np.ndarray + Volume in standard space. + ac_coords_standardized : np.ndarray + AC coordinates in standard space. + pc_coords_standardized : np.ndarray + PC coordinates in standard space. + standardized_affine : np.ndarray + Affine matrix for standard space. """ image_center = np.array(orig.shape) / 2 @@ -156,9 +160,8 @@ def get_mapping_to_standard_space( # convert 2D nodding correction to 3D transformation matrix nod_correct_3d = np.eye(4) nod_correct_3d[1:3, 1:3] = nod_correct_2d[:2, :2] # Copy rotation part to y,z axes - nod_correct_3d[1:3, 3] = nod_correct_2d[ - :2, 2 - ] # Copy translation part to y,z axes (usually no translation) + # Copy translation part to y,z axes (usually no translation) + nod_correct_3d[1:3, 3] = nod_correct_2d[:2, 2] ac_coords_after_nodding = apply_transform_to_pt( ac_coords_3d, nod_correct_3d, inv=False @@ -168,9 +171,7 @@ def get_mapping_to_standard_space( ) ac_to_center_translation = np.eye(4) - ac_to_center_translation[0, 3] = image_center[0] - ac_coords_after_nodding[0] - ac_to_center_translation[1, 3] = image_center[1] - ac_coords_after_nodding[1] - ac_to_center_translation[2, 3] = image_center[2] - ac_coords_after_nodding[2] + ac_to_center_translation[:3, 3] = image_center - ac_coords_after_nodding # correct nodding ac_coords_standardized = apply_transform_to_pt( @@ -205,8 +206,8 @@ def get_mapping_to_standard_space( def apply_transform_to_volume( volume: np.ndarray, - transform: np.ndarray, - affine: np.ndarray, + transform: npt.NDArray[float], + affine: npt.NDArray[float], header: nib.freesurfer.mghformat.MGHHeader, output_path: str | Path | None = None, output_size: np.ndarray | None = None, @@ -217,24 +218,24 @@ def apply_transform_to_volume( Parameters ---------- volume : np.ndarray - Input volume data + Input volume data. transform : np.ndarray - Transformation matrix to apply + Transformation matrix to apply. 
affine : np.ndarray - Affine matrix for the output image + Affine matrix for the output image. header : nib.freesurfer.mghformat.MGHHeader - Header for the output image - output_path : str or Path or None, optional - Path to save transformed volume, by default None - output_size : np.ndarray or None, optional - Size of output volume, by default None (uses input size) - order : int, optional - Order of interpolation, by default 1 + Header for the output image. + output_path : str or Path, optional + Path to save transformed volume. + output_size : np.ndarray, optional + Size of output volume, uses input size by default (`None`). + order : int, default=1 + Order of interpolation. Returns ------- np.ndarray - Transformed volume data + Transformed volume data. Notes ----- @@ -256,18 +257,18 @@ def apply_transform_to_volume( return transformed -def make_affine(simpleITKImage: 'sitk.Image') -> np.ndarray: +def make_affine(simpleITKImage: 'sitk.Image') -> npt.NDArray[float]: """Create an affine transformation matrix from a SimpleITK image. Parameters ---------- simpleITKImage : sitk.Image - Input SimpleITK image + Input SimpleITK image. Returns ------- np.ndarray - 4x4 affine transformation matrix in RAS coordinates + 4x4 affine transformation matrix in RAS coordinates. Notes ----- @@ -277,10 +278,7 @@ def make_affine(simpleITKImage: 'sitk.Image') -> np.ndarray: 3. Returns the final 4x4 transformation matrix """ # get affine transform in LPS - c = [ - simpleITKImage.TransformContinuousIndexToPhysicalPoint(p) - for p in ((1, 0, 0), (0, 1, 0), (0, 0, 1), (0, 0, 0)) - ] + c = [simpleITKImage.TransformContinuousIndexToPhysicalPoint(p) for p in np.eye(4)[:, :3]] c = np.array(c) affine = np.concatenate( [np.concatenate([c[0:3] - c[3:], c[3:]], axis=0), [[0.0], [0.0], [0.0], [1.0]]], @@ -293,8 +291,8 @@ def make_affine(simpleITKImage: 'sitk.Image') -> np.ndarray: def map_softlabels_to_orig( - outputs_soft: np.ndarray, - orig_fsaverage_vox2vox: np.ndarray, + outputs_soft: npt.NDArray[float], + orig_fsaverage_vox2vox: npt.NDArray[float], orig: np.ndarray, slices_to_analyze: int, orig_space_segmentation_path: str | Path | None = None, @@ -306,24 +304,24 @@ def map_softlabels_to_orig( Parameters ---------- outputs_soft : np.ndarray - Soft label predictions + Soft label predictions. orig_fsaverage_vox2vox : np.ndarray - Original to fsaverage space transformation + Original to fsaverage space transformation. orig : np.ndarray - Original image + Original image. slices_to_analyze : int - Number of slices to analyze - orig_space_segmentation_path : str or Path or None, optional - Path to save segmentation in original space, by default None - fsaverage_middle : int, optional - Middle slice index in fsaverage space, by default 128 - subdivision_mask : np.ndarray or None, optional - Mask for subdividing regions, by default None + Number of slices to analyze. + orig_space_segmentation_path : str or Path, optional + Path to save segmentation in original space. + fsaverage_middle : int, default=128 + Middle slice index in fsaverage space. + subdivision_mask : np.ndarray, optional + Mask for subdividing regions. Returns ------- np.ndarray - Final segmentation in original image space + Final segmentation in original image space. 
    Notes
    -----
@@ -337,17 +335,12 @@
     """
     # map softlabels to original image
+    pad_lr = (fsaverage_middle - slices_to_analyze // 2, orig.shape[0] - (fsaverage_middle + slices_to_analyze // 2 + 1))
+    pad_tuples = (pad_lr,) + ((0, 0),) * (orig.ndim - 1)
     softlabels_transformed = []
     for i in range(outputs_soft.shape[-1]):
-        # pad to original image size
-        outputs_soft_padded = np.zeros(orig.shape)
-        outputs_soft_padded[
-            fsaverage_middle
-            - slices_to_analyze // 2 : fsaverage_middle
-            + slices_to_analyze // 2
-            + 1
-        ] = outputs_soft[..., i]
+        # pad with (before, after) slice counts so the slab lands at the midplane
+        outputs_soft_padded = np.pad(outputs_soft[..., i], pad_tuples)

         s = affine_transform(
             outputs_soft_padded,
@@ -361,15 +354,14 @@
     softlabels_orig_space = np.stack(softlabels_transformed, axis=-1)

     # apply softmax to softlabels_orig_space
-    softlabels_orig_space = np.exp(softlabels_orig_space) / np.sum(
-        np.exp(softlabels_orig_space), axis=-1, keepdims=True
-    )
+    exp_orig_space = np.exp(softlabels_orig_space)
+    softlabels_orig_space = exp_orig_space / np.sum(exp_orig_space, axis=-1, keepdims=True)

     segmentation_orig_space = np.argmax(softlabels_orig_space, axis=-1)

     if subdivision_mask is not None:
         # repeat subdivision mask for shape 0 of orig
-        subdivision_mask = np.repeat(subdivision_mask[np.newaxis, :, :], orig.shape[0], axis=0)
+        subdivision_mask = np.repeat(subdivision_mask[np.newaxis], orig.shape[0], axis=0)
         # map subdivision mask to orig space
         subdivision_mask_orig_space = affine_transform(
             subdivision_mask,
@@ -378,19 +370,11 @@
             order=0,
         )

-        segmentation_orig_space[segmentation_orig_space == 1] = \
-            segmentation_orig_space[segmentation_orig_space == 1] * \
-            subdivision_mask_orig_space[segmentation_orig_space == 1]
-
-
-    segmentation_orig_space = np.where(
-        segmentation_orig_space == 1, 192, segmentation_orig_space
-    )
-    segmentation_orig_space = np.where(
-        segmentation_orig_space == 2, 250, segmentation_orig_space
-    )
+        mask = segmentation_orig_space == 1
+        segmentation_orig_space[mask] *= subdivision_mask_orig_space[mask]

-
+    # identity LUT keeps subsegment labels (251-255) intact and remaps 1 -> 192, 2 -> 250
+    seg_lut = np.arange(256)
+    seg_lut[[1, 2]] = [192, 250]
+    segmentation_orig_space = seg_lut[segmentation_orig_space]

     if orig_space_segmentation_path is not None:
         logger.info(f"Saving segmentation in original space to {orig_space_segmentation_path}")
diff --git a/CorpusCallosum/segmentation/segmentation_inference.py b/CorpusCallosum/segmentation/segmentation_inference.py
index b0b6c0b6..8878df53 100644
--- a/CorpusCallosum/segmentation/segmentation_inference.py
+++ b/CorpusCallosum/segmentation/segmentation_inference.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
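A note on the `np.pad` rewrite in `map_softlabels_to_orig` above: `np.pad` expects (before, after) *widths*, not slice indices, which is what the corrected `pad_lr` encodes. A toy check with made-up sizes:

```python
import numpy as np

total, middle, half = 256, 128, 2     # volume slices, midplane index, slab half-width
slab = np.ones((2 * half + 1, 3, 3))  # the segmented midplane slab
pad = ((middle - half, total - (middle + half + 1)), (0, 0), (0, 0))
padded = np.pad(slab, pad)
assert padded.shape[0] == total and padded[middle - half:middle + half + 1].all()
```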
+from pathlib import Path

 import nibabel as nib
 import numpy as np
@@ -116,8 +117,6 @@ def run_inference(
         - segmentation : Binary segmentation map
         - landmarks : Predicted landmark coordinates
     """
-    orig_shape = image_slice.shape
-
     if device is None:
         #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         device = next(model.parameters()).device
@@ -162,50 +161,26 @@ def crop_around_acpc(img: np.ndarray,
     # split into slices with 9 channels each
     # Generate views with sliding window of 9 slices
     batch_size, channels, height, width = inputs.shape
-    views = []
-    for i in range(batch_size - 8):  # -8 to ensure we have 9 slices per view
-        view = inputs[i:i+9]  # Take 9 consecutive slices
-        view = view.reshape(1, 9*channels, height, width)  # Reshape to combine slices into channels
-        views.append(view)
-
-    inputs = torch.cat(views, dim=0)  # Stack all views into batch dimension
+    # move the 9-slice window next to the batch axis before flattening it into the channels
+    inputs = inputs.unfold(0, 9, 1).movedim(-1, 1).reshape(-1, 9*channels, height, width)

     # Post-process outputs
     with torch.no_grad():
-        scale_factors = torch.ones((inputs.shape[0], 2), device=device) * (1 / voxel_size)
+        scale_factors = torch.ones((inputs.shape[0], 2), device=device) / voxel_size
         outputs = model(inputs, scale_factor=scale_factors)

         # average the outputs along the batch dimension
-        outputs_avg = torch.mean(outputs, dim=0).unsqueeze(0)
+        outputs_avg = torch.mean(outputs, dim=0, keepdim=True)
         outputs_soft = outputs.cpu().numpy()  #transforms.Activations(softmax=True)(outputs)  # non_discrete outputs

         outputs = torch.stack([post_trans(i) for i in outputs])
         outputs_avg = torch.stack([post_trans(i) for i in outputs_avg])

-    pad_left, pad_right, pad_top, pad_bottom = to_pad
-    # Pad back to original size
-    outputs = np.pad(outputs, ((0,0), (0,0), (pad_left.item(), pad_right.item()),
-                               (pad_top.item(), pad_bottom.item())), mode='constant', constant_values=0)
-    outputs_avg = np.pad(outputs_avg, ((0,0), (0,0), (pad_left.item(), pad_right.item()),
-                                       (pad_top.item(), pad_bottom.item())), mode='constant', constant_values=0)
-    outputs_soft = np.pad(outputs_soft, ((0,0), (0,0), (pad_left.item(), pad_right.item()),
-                                         (pad_top.item(), pad_bottom.item())), mode='constant', constant_values=0)
-
-    # restore original shape
-    if orig_shape[-2:] != outputs.shape[-2:]:
-        new_outputs = np.zeros((outputs.shape[0], outputs.shape[1], orig_shape[-2], orig_shape[-1]))
-        new_outputs[:,:,:256,:256] = outputs
-        outputs = new_outputs
-
-        new_outputs_avg = np.zeros((outputs_avg.shape[0], outputs_avg.shape[1], orig_shape[-2], orig_shape[-1]))
-        new_outputs_avg[:,:,:256,:256] = outputs_avg
-        outputs_avg = new_outputs_avg
-
-        new_outputs_soft = np.zeros((outputs_soft.shape[0], outputs_soft.shape[1],
-                                     orig_shape[-2], orig_shape[-1]), dtype=np.float32)
-        new_outputs_soft[:,:,:256,:256] = outputs_soft
-        outputs_soft = new_outputs_soft
+    # Pad back to original size, to_pad is a tuple[int, int, int, int]
+    pad_tuples = ((0, 0),) * 2 + (to_pad[:2], to_pad[2:])
+    outputs = np.pad(outputs, pad_tuples, mode='constant', constant_values=0)
+    outputs_avg = np.pad(outputs_avg, pad_tuples, mode='constant', constant_values=0)
+    outputs_soft = np.pad(outputs_soft, pad_tuples, mode='constant', constant_values=0)

     return (
         outputs.transpose(0,2,3,1),
@@ -216,6 +191,8 @@

 def load_validation_data(path):
+    from concurrent.futures import ThreadPoolExecutor
+    import pandas as pd
     data = pd.read_csv(path, index_col=0, header=None)
     data.columns = ["image", "label", "AC_center_x", "AC_center_y", "AC_center_z",
@@ -227,21 +204,19 @@ def load_validation_data(path):
     labels = data["label"].values
     subj_ids = data.index.values.tolist()

-    label_widths = []
-    for label_path in data['label']:
-        label_img =nib.load(label_path)
+    def _load(label_path: str | Path) -> int:
+        label_img = nib.load(label_path)
         if label_img.shape[0] > 100:
             # check which slices have non-zero values
-            label = label_img.get_fdata()
-            non_zero_slices = np.any(label > 0, axis=(1,2))
+            label_data = np.asarray(label_img.dataobj)
+            non_zero_slices = np.any(label_data > 0, axis=(1,2))
             first_nonzero = np.argmax(non_zero_slices)
             last_nonzero = len(non_zero_slices) - np.argmax(non_zero_slices[::-1])
-            label_widths.append(last_nonzero - first_nonzero)
+            return last_nonzero - first_nonzero
         else:
-            label_widths.append(label_img.shape[0])
-
-
+            return label_img.shape[0]
+    # materialize the results and shut the pool down; a bare ThreadPoolExecutor().map(...)
+    # would leak the executor and hand back a lazy iterator instead of a list
+    with ThreadPoolExecutor() as pool:
+        label_widths = list(pool.map(_load, data["label"]))

     return images, ac_centers, pc_centers, label_widths, labels, subj_ids
@@ -251,10 +226,7 @@ def one_hot_to_label(one_hot, label_ids=None):
         label_ids = [0, 192, 250]
     label = np.argmax(one_hot, axis=3)
     if label_ids is not None:
-        label = np.where(label == 0, label_ids[0], label)
-        label = np.where(label == 1, label_ids[1], label)
-        label = np.where(label == 2, label_ids[2], label)
-
+        label = np.asarray(label_ids)[label]
     return label
diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py
index 542e6f4a..f084eaaa 100644
--- a/CorpusCallosum/segmentation/segmentation_postprocessing.py
+++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import numpy as np
+from numpy import typing as npt
 from scipy import integrate, ndimage
 from scipy.spatial.distance import cdist
 from skimage.measure import label
@@ -23,7 +24,7 @@
 logger = logging.get_logger(__name__)


-def find_component_boundaries(labels_arr: np.ndarray, component_id: int) -> np.ndarray:
+def find_component_boundaries(labels_arr: npt.NDArray[int], component_id: int) -> npt.NDArray[int]:
     """Find boundary voxels of a connected component.

     Parameters
@@ -59,19 +60,19 @@ def find_component_boundaries(labels_arr: np.ndarray, component_id: int) -> np.n


 def find_minimal_connection_path(
-    boundary1: np.ndarray,
-    boundary2: np.ndarray,
+    boundary_coords1: np.ndarray,
+    boundary_coords2: np.ndarray,
     max_distance: float = 3.0
 ) -> tuple[np.ndarray, np.ndarray] | None:
     """Find the minimal connection path between two component boundaries.

     Parameters
     ----------
-    boundary1 : np.ndarray
+    boundary_coords1 : np.ndarray
         Boundary coordinates of first component, shape (N1, 3)
-    boundary2 : np.ndarray
+    boundary_coords2 : np.ndarray
         Boundary coordinates of second component, shape (N2, 3)
-    max_distance : float, optional
+    max_distance : float, default=3.0
         Maximum distance to consider for connection, by default 3.0

     Returns
@@ -87,20 +88,18 @@
     Uses Euclidean distance to find the closest pair of points between the two boundaries.
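The closest-pair search described in this docstring reduces to a `cdist` matrix plus `argmin`/`unravel_index`; a standalone sketch under the same assumptions (points as row vectors):

```python
import numpy as np
from scipy.spatial.distance import cdist

def closest_pair(a: np.ndarray, b: np.ndarray, max_distance: float = 3.0):
    """Closest points between two coordinate sets, or None if too far apart."""
    if len(a) == 0 or len(b) == 0:
        return None
    d = cdist(a, b)                                 # (len(a), len(b)) pairwise distances
    i, j = np.unravel_index(np.argmin(d), d.shape)  # indices of the minimum entry
    return (a[i], b[j]) if d[i, j] <= max_distance else None

pair = closest_pair(np.array([[0., 0., 0.]]), np.array([[1., 0., 0.], [9., 9., 9.]]))
assert np.allclose(pair[1], [1., 0., 0.])
```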
""" - if len(boundary1) == 0 or len(boundary2) == 0: + if len(boundary_coords1) == 0 or len(boundary_coords2) == 0: return None # Calculate pairwise distances between all boundary points - distances = cdist(boundary1, boundary2, metric='euclidean') + distances = cdist(boundary_coords1, boundary_coords2, metric='euclidean') # Find the minimum distance and corresponding points min_idx = np.unravel_index(np.argmin(distances), distances.shape) min_distance = distances[min_idx] if min_distance <= max_distance: - point1 = boundary1[min_idx[0]] - point2 = boundary2[min_idx[1]] - return point1, point2 + return boundary_coords1[min_idx[0]], boundary_coords2[min_idx[1]] return None diff --git a/CorpusCallosum/shape/cc_mesh.py b/CorpusCallosum/shape/cc_mesh.py index 0824a6b2..dc6b7dad 100644 --- a/CorpusCallosum/shape/cc_mesh.py +++ b/CorpusCallosum/shape/cc_mesh.py @@ -24,7 +24,6 @@ import pyrr import scipy.interpolate from scipy.ndimage import gaussian_filter1d -from whippersnappy.core import snap1 import FastSurferCNN.utils.logging as logging from CorpusCallosum.shape.cc_endpoint_heuristic import smooth_contour @@ -1272,6 +1271,14 @@ def snap_cc_picture( 3. Cleans up temporary files after use. """ + try: + from whippersnappy.core import snap1 + except ImportError: + # whippersnappy not installed + raise RuntimeError( + "The snap_cc_picture method of CCMesh requires whippersnappy, but whippersnappy was not found. " + "Please install whippersnappy!" + ) from None self.__make_parent_folder(output_path) # Skip snapshot if there are no faces if len(self.t) == 0: diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/cc_postprocessing.py index e11e490b..f5ad2054 100644 --- a/CorpusCallosum/shape/cc_postprocessing.py +++ b/CorpusCallosum/shape/cc_postprocessing.py @@ -18,7 +18,7 @@ import numpy as np import FastSurferCNN.utils.logging as logging -from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE, SUBSEGEMNT_LABELS +from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE, SUBSEGMENT_LABELS from CorpusCallosum.data.read_write import run_in_background from CorpusCallosum.shape.cc_endpoint_heuristic import get_endpoints from CorpusCallosum.shape.cc_mesh import CC_Mesh @@ -628,7 +628,7 @@ def make_subdivision_mask( rows, cols = slice_shape y_coords, x_coords = np.mgrid[0:rows, 0:cols] - subsegment_labels_anterior_posterior = SUBSEGEMNT_LABELS.copy() + subsegment_labels_anterior_posterior = SUBSEGMENT_LABELS.copy() subsegment_labels_anterior_posterior.reverse() # Initialize with first segment label diff --git a/CorpusCallosum/shape/cc_thickness.py b/CorpusCallosum/shape/cc_thickness.py index 788101b2..bb6bf7ba 100644 --- a/CorpusCallosum/shape/cc_thickness.py +++ b/CorpusCallosum/shape/cc_thickness.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import meshpy.triangle as triangle import numpy as np import scipy.interpolate from lapy import Solver, TriaMesh from lapy.diffgeo import compute_rotated_f +from meshpy import triangle from CorpusCallosum.utils.utils import HiddenPrints diff --git a/CorpusCallosum/transforms/localization_transforms.py b/CorpusCallosum/transforms/localization_transforms.py index a02b7e89..f7fa3f0f 100644 --- a/CorpusCallosum/transforms/localization_transforms.py +++ b/CorpusCallosum/transforms/localization_transforms.py @@ -93,30 +93,23 @@ def __call__(self, data: dict) -> dict: center_point = ((ac_center + pc_center) / 2).astype(int) # Calculate voxel padding based on mm padding - voxel_padding_x = self.fixed_size[0] // 2 - voxel_padding_y = self.fixed_size[1] // 2 + voxel_padding = np.asarray(self.fixed_size) // 2 # Add random translation if specified if self.random_translate > 0: random_translate = np.random.randint(-self.random_translate, self.random_translate, size=2) else: - random_translate = (0,0) - - + random_translate = np.asarray((0, 0)) # Calculate crop boundaries with padding and random translation - crop_left = center_point[1] - voxel_padding_x + random_translate[0] - crop_right = center_point[1] + voxel_padding_x + random_translate[0] - crop_top = center_point[2] - voxel_padding_y + random_translate[1] - crop_bottom = center_point[2] + voxel_padding_y + random_translate[1] - + crops = center_point - voxel_padding + random_translate + # Ensure crop boundaries are within image - #img_shape = d['image'].shape[2:] # Get spatial dimensions - # crop_left = max(0, crop_left) - # crop_right = min(img_shape[0], crop_right) - # crop_top = max(0, crop_top) - # crop_bottom = min(img_shape[1], crop_bottom) + img_shape = np.asarray(d['image'].shape[2:]) # Get spatial dimensions + crops = np.maximum(0, np.minimum(img_shape, crops + np.asarray(self.fixed_size)) - np.asarray(self.fixed_size)) + crop_left, crop_top = crops.tolist() + crop_right, crop_bottom = (crops + np.asarray(self.fixed_size)).tolist() # raise error if crop boundaries are out of image if crop_left < 0 or crop_right > d['image'].shape[2] or crop_top < 0 or crop_bottom > d['image'].shape[3]: diff --git a/CorpusCallosum/transforms/segmentation_transforms.py b/CorpusCallosum/transforms/segmentation_transforms.py index 051bc5b4..0a11d9b7 100644 --- a/CorpusCallosum/transforms/segmentation_transforms.py +++ b/CorpusCallosum/transforms/segmentation_transforms.py @@ -26,12 +26,12 @@ class CropAroundACPC(RandomizableTransform, MapTransform): ---------- keys : list[str] Keys of the data dictionary to apply the transform to - allow_missing_keys : bool, optional - Whether to allow missing keys in the data dictionary, by default False - padding_mm : float, optional - Padding around AC-PC region in millimeters, by default 10 - random_translate : float, optional - Maximum random translation in voxels, by default 0 + allow_missing_keys : bool, default=False + Whether to allow missing keys in the data dictionary. + padding_mm : float, default=10.0 + Padding around AC-PC region in millimeters. + random_translate : float, default=0 + Maximum random translation in voxels, off by default. 
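The vectorized crop computation in the localization transform above clamps a fixed-size window so it never overhangs the image. The clamping rule in isolation (illustrative names and sizes):

```python
import numpy as np

def clamp_crop(center: np.ndarray, size: np.ndarray, shape: np.ndarray):
    """Shift a fixed-size crop window so it lies fully inside the image bounds."""
    start = center - size // 2
    # pull the window back from the far edge, then forward from the near edge
    start = np.maximum(0, np.minimum(shape, start + size) - size)
    return start, start + size

start, stop = clamp_crop(np.array([5, 60]), np.array([64, 64]), np.array([256, 256]))
assert (start >= 0).all() and (stop <= np.array([256, 256])).all()
```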
    Notes
    -----
diff --git a/CorpusCallosum/visualization/visualization.py b/CorpusCallosum/visualization/visualization.py
index 99027880..78e7c1a4 100644
--- a/CorpusCallosum/visualization/visualization.py
+++ b/CorpusCallosum/visualization/visualization.py
@@ -47,20 +47,10 @@ def plot_standardized_space(
     """
     ax_row[0].set_title("Standardized")

-    # Axial view
-    ax_row[0].scatter(ac_coords[2], ac_coords[1], color="red", marker="x")
-    ax_row[0].scatter(pc_coords[2], pc_coords[1], color="blue", marker="x")
-    ax_row[0].imshow(vol[vol.shape[0] // 2], cmap="gray")
-
-    # Sagittal view
-    ax_row[1].scatter(ac_coords[2], ac_coords[0], color="red", marker="x")
-    ax_row[1].scatter(pc_coords[2], pc_coords[0], color="blue", marker="x")
-    ax_row[1].imshow(vol[:, vol.shape[1] // 2], cmap="gray")
-
-    # Coronal view
-    ax_row[2].scatter(ac_coords[1], ac_coords[0], color="red", marker="x")
-    ax_row[2].scatter(pc_coords[1], pc_coords[0], color="blue", marker="x")
-    ax_row[2].imshow(vol[:, :, vol.shape[2] // 2], cmap="gray")
+    for i, (a, b, _) in enumerate(((2, 1, "Axial"), (2, 0, "Sagittal"), (1, 0, "Coronal"))):
+        ax_row[i].scatter(ac_coords[a], ac_coords[b], color="red", marker="x")
+        ax_row[i].scatter(pc_coords[a], pc_coords[b], color="blue", marker="x")
+        ax_row[i].imshow(vol[(slice(None),) * i + (vol.shape[i] // 2,)], cmap="gray")


 def visualize_coordinate_spaces(
@@ -105,7 +95,11 @@
     Notes
     -----
-    Saves the visualization as 'ac_pc_spaces.png' in the output directory.
+    Saves a visualization of the anterior (red) and posterior (blue) commissure in three different views:
+    1. the orig image (orig),
+    2. fs-average standardized image space, and
+    3. standardized image space
+    as a single image named 'ac_pc_spaces.png' in `output_dir`.
     """
     fig, ax = plt.subplots(3, 4)
     ax = ax.T
@@ -174,15 +168,14 @@
     """
     # scale contour data by vox_size
-    split_contours = (
-        [split_contour / vox_size for split_contour in split_contours] if split_contours is not None else None
-    )
-    midline_equidistant = midline_equidistant / vox_size if midline_equidistant is not None else None
-    levelpaths = [levelpath / vox_size for levelpath in levelpaths] if levelpaths is not None else None
+    if split_contours is not None:
+        split_contours = [split_contour / vox_size for split_contour in split_contours]
+    if midline_equidistant is not None:
+        midline_equidistant = midline_equidistant / vox_size
+    if levelpaths is not None:
+        levelpaths = [levelpath / vox_size for levelpath in levelpaths]

-    NO_PLOTS = 1
-    if split_contours is not None:
-        NO_PLOTS += 1
+    NO_PLOTS = 1 + int(split_contours is not None)

     _, ax = plt.subplots(1, NO_PLOTS, sharex=True, sharey=True, figsize=(15, 10))
@@ -256,9 +249,7 @@ def plot_midplane(grid_orig: np.ndarray, orig: np.ndarray) -> None:

     # Plot every 10th point to avoid overcrowding
     sample_idx = np.arange(0, grid_orig.shape[1], 40)
-    ax.scatter(
-        grid_orig[0, sample_idx], grid_orig[1, sample_idx], grid_orig[2, sample_idx], c="r", alpha=0.1, marker="."
- ) + ax.scatter(*grid_orig[:3, sample_idx], c="r", alpha=0.1, marker=".") # Set labels ax.set_xlabel("X") diff --git a/FastSurferCNN/download_checkpoints.py b/FastSurferCNN/download_checkpoints.py index 7d7a620a..35492d79 100644 --- a/FastSurferCNN/download_checkpoints.py +++ b/FastSurferCNN/download_checkpoints.py @@ -112,15 +112,12 @@ def main( hypvinn: bool = False, cc: bool = False, all: bool = False, - files: list[str] = None, + files: list[str] = (), url: str | None = None, ) -> int | str: if not vinn and not files and not cerebnet and not hypvinn and not cc and not all: return ("Specify either files to download or --vinn, --cerebnet, " "--hypvinn, or --all, see help -h.") - - if files is None: - files = [] try: # FastSurferVINN checkpoints diff --git a/doc/conf.py b/doc/conf.py index f9ea8d7b..0921c6b2 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -12,8 +12,7 @@ import os from pathlib import Path -# here i added the relative path because sphinx was not able -# to locate FastSurferCNN module directly for autosummary +# relative path so sphinx can locate the different modules directly for autosummary sys.path.append(os.path.dirname(__file__) + "/..") sys.path.append(os.path.dirname(__file__) + "/../recon_surf") sys.path.append(os.path.dirname(__file__) + "/sphinx_ext") diff --git a/env/fastsurfer.yml b/env/fastsurfer.yml index 2913c0c7..7c73f303 100644 --- a/env/fastsurfer.yml +++ b/env/fastsurfer.yml @@ -7,6 +7,7 @@ dependencies: - h5py==3.12.1 - lapy==1.2.0 - matplotlib==3.10.1 +- monai==1.4.0 - nibabel==5.3.2 - numpy==1.26.4 - pandas==2.2.3 From 908daca0fef5ac969420a71c3fb36ec1189c4c83 Mon Sep 17 00:00:00 2001 From: David Kuegler Date: Wed, 12 Nov 2025 01:17:17 +0100 Subject: [PATCH 27/68] Fixes broken by history rewrite (merge => rebase) --- CorpusCallosum/data/constants.py | 9 +- CorpusCallosum/fastsurfer_cc.py | 106 +++++++++--------- .../localization/localization_inference.py | 31 ++--- CorpusCallosum/visualization/visualization.py | 19 ++-- 4 files changed, 76 insertions(+), 89 deletions(-) diff --git a/CorpusCallosum/data/constants.py b/CorpusCallosum/data/constants.py index 6cc2101a..c5b9e63d 100644 --- a/CorpusCallosum/data/constants.py +++ b/CorpusCallosum/data/constants.py @@ -15,15 +15,16 @@ from pathlib import Path +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT + ### Constants -WEIGHTS_PATH = Path(__file__).parent.parent.parent / "checkpoints" -FSAVERAGE_CENTROIDS_PATH = Path(__file__).parent / "fsaverage_centroids.json" -FSAVERAGE_DATA_PATH = Path(__file__).parent / "fsaverage_data.json" # Contains both affine and header +WEIGHTS_PATH = FASTSURFER_ROOT / "checkpoints" +FSAVERAGE_CENTROIDS_PATH = FASTSURFER_ROOT / "CorpusCallosum" / "fsaverage_centroids.json" +FSAVERAGE_DATA_PATH = FASTSURFER_ROOT / "CorpusCallosum" / "fsaverage_data.json" # Contains both affine and header FSAVERAGE_MIDDLE = 128 # Middle slice index in fsaverage space CC_LABEL = 192 # Label value for corpus callosum in segmentation FORNIX_LABEL = 250 # Label value for fornix in segmentation SUBSEGMENT_LABELS = [251, 252, 253, 254, 255] # labels for subsegments in segmentation -FASTSURFER_ROOT = Path(__file__).parent.parent.parent # TODO: use FastSurfer function for this STANDARD_INPUT_PATHS = { diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index d0208997..2e98ac64 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. + import argparse import json from pathlib import Path @@ -63,7 +64,7 @@ logger = logging.get_logger(__name__) SliceSelection = Literal["middle", "all"] | int SubdivisionMethod = Literal["shape", "vertical", "angular", "eigenvector"] - + def make_parser() -> argparse.ArgumentParser: """Create the argument parse object for the pipeline.""" @@ -104,7 +105,8 @@ def make_parser() -> argparse.ArgumentParser: parser.add_argument( "--num_thickness_points", - type=int, default=100, + type=int, + default=100, help="Number of points for thickness estimation." ) parser.add_argument( @@ -137,7 +139,7 @@ def _slice_selection(a: str) -> SliceSelection: return int(a) parser.add_argument( "--slice_selection", - type=str, + type=_slice_selection, default="all", help="Which slices to process. Options: 'middle', 'all', or a specific slice number. \ (default: 'all')", @@ -165,7 +167,8 @@ def _slice_selection(a: str) -> SliceSelection: default=None, ) advanced.add_argument( - "--segmentation_path", + "--seg", + dest="segmentation_path", type=Path, help=f"Path for segmentation output (default: subject_dir/{STANDARD_OUTPUT_PATHS['segmentation']})", default=None, @@ -512,37 +515,35 @@ def segment_cc( return segmentation, outputs_soft - - def main( - t1_path: str | Path, - aseg_path: str | Path, - output_dir: str | Path, - slice_selection: str = "middle", - qc_output_dir: str | Path = None, + conf_name: str | Path, + aseg_name: str | Path, + subject_dir: str | Path, + slice_selection: SliceSelection = "middle", + qc_output_dir: str | Path | None = None, verbose: bool = False, num_thickness_points: int = 100, subdivisions: list[float] | None = None, subdivision_method: SubdivisionMethod = "shape", contour_smoothing: float = 5, save_template_dir: str | Path | None = None, - device: str = "auto", - upright_volume_path: str | Path = None, - segmentation_path: str | Path = None, - postproc_results_path: str | Path = None, - cc_markers_path: str | Path = None, - upright_lta_path: str | Path = None, - orient_volume_lta_path: str | Path = None, - surf_file_path: str | Path = None, - overlay_file_path: str | Path = None, - cc_html_path: str | Path = None, - vtk_file_path: str | Path = None, - orig_space_segmentation_path: str | Path = None, - qc_image_path: str | Path = None, - thickness_image_path: str | Path = None, - softlabels_cc_path: str | Path = None, - softlabels_fn_path: str | Path = None, - softlabels_background_path: str | Path = None, + device: str | torch.device = "auto", + upright_volume_path: str | Path | None = None, + segmentation_path: str | Path | None = None, + postproc_results_path: str | Path | None = None, + cc_markers_path: str | Path | None = None, + upright_lta_path: str | Path | None = None, + orient_volume_lta_path: str | Path | None = None, + surf_file_path: str | Path | None = None, + overlay_file_path: str | Path | None = None, + cc_html_path: str | Path | None = None, + vtk_file_path: str | Path | None = None, + orig_space_segmentation_path: str | Path | None = None, + qc_image_path: str | Path | None = None, + thickness_image_path: str | Path | None = None, + softlabels_cc_path: str | Path | None = None, + softlabels_fn_path: str | Path | None = None, + softlabels_background_path: str | Path | None = None, ) -> None: """Main pipeline function for corpus callosum analysis. @@ -555,16 +556,16 @@ def main( Path to input MRI file. aseg_name : str or Path Path to input segmentation file. 
-    output_dir : str or Path
-        Directory for output files.
-    slice_selection : str, optional
-        Which slices to process ('middle', 'all', or specific slice number), by default 'middle'.
+    subject_dir : str or Path
+        FastSurfer/FreeSurfer subject directory and directory for output files.
+    slice_selection : "middle", "all" or int, default="middle"
+        Which slices to process.
     qc_output_dir : str or Path, optional
-        Directory for quality control outputs, by default None.
-    verbose : bool, optional
-        Flag for verbose output, by default False.
-    num_thickness_points : int, optional
-        Number of points for thickness estimation, by default 100.
+        Directory for quality control outputs, None deactivates qc snapshots.
+    verbose : bool, default=False
+        Flag for verbose output.
+    num_thickness_points : int, default=100
+        Number of points for thickness estimation.
     subdivisions : list[float], optional
         List of subdivision fractions for CC subsegmentation.
     subdivision_method : any of "shape", "vertical", "angular", "eigenvector", default="shape"
@@ -572,10 +573,10 @@
     contour_smoothing : float, default=5
         Gaussian sigma for smoothing during contour detection.
     save_template_dir : str or Path, optional
-        Directory path where to save contours.txt and thickness_values.txt files. \
-        These files can be used to visualize the CC shape and volume in 3D.
-    device : str, optional
-        Device to run inference on ('auto', 'cpu', 'cuda', or 'cuda:X'), by default 'auto'.
+        Directory path where to save contours.txt and thickness_values.txt files. These files can be used to visualize
+        the CC shape and volume in 3D. Files are only saved if a valid directory path is passed.
+    device : str, default="auto"
+        Device to run inference on ('auto', 'cpu', 'cuda', or 'cuda:X').
     upright_volume_path : str or Path, optional
         Path to save upright volume.
     segmentation_path : str or Path, optional
@@ -597,9 +598,9 @@
     vtk_file_path : str or Path, optional
         Path to save VTK file.
     orig_space_segmentation_path : str or Path, optional
-        Path to save segmentation in original space, by default None.
+        Path to save segmentation in original space.
     qc_image_path : str or Path, optional
-        Path to save QC images, by default None.
+        Path to save QC images.
     thickness_image_path : str or Path, optional
         Path to save thickness visualization.
softlabels_cc_path : str or Path, optional @@ -636,15 +637,15 @@ def main( logging.setup_logging(None) # Log to stdout only logger.info("Starting corpus callosum analysis pipeline") - logger.info(f"Input MRI: {t1_path}") - logger.info(f"Input segmentation: {aseg_path}") - logger.info(f"Output directory: {output_dir}") + logger.info(f"Input MRI: {conf_name}") + logger.info(f"Input segmentation: {aseg_name}") + logger.info(f"Output directory: {subject_dir}") # Convert all paths to Path objects - t1_path = Path(t1_path) - aseg_path = Path(aseg_path) - if output_dir is not None: - output_dir = Path(output_dir) + t1_path = Path(conf_name) + aseg_path = Path(aseg_name) + if subject_dir: + subject_dir = Path(subject_dir) if save_template_dir: save_template_dir = Path(save_template_dir) @@ -675,10 +676,7 @@ def main( raise ValueError("MRI is not conformed, please run conform.py or mri_convert to conform the image.") # load models - if device == "auto": - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - else: - device = torch.device(device) + device = find_device(device) logger.info(f"Using device: {device}") logger.info("Loading models") diff --git a/CorpusCallosum/localization/localization_inference.py b/CorpusCallosum/localization/localization_inference.py index 105a57c0..4b8cf025 100644 --- a/CorpusCallosum/localization/localization_inference.py +++ b/CorpusCallosum/localization/localization_inference.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Callable from pathlib import Path import numpy as np @@ -20,34 +19,26 @@ from monai import transforms from monai.networks.nets import DenseNet -from CorpusCallosum.data import constants +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT from CorpusCallosum.transforms.localization_transforms import CropAroundACPCFixedSize from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults from FastSurferCNN.download_checkpoints import main as download_checkpoints - -def load_model(device: torch.device | None = None) -> DenseNet: +def load_model(device: torch.device) -> DenseNet: """Load trained numerical localization model from checkpoint. Parameters ---------- - checkpoint_path : str or Path or None, optional - Path to model checkpoint, by default None. - If None, downloads and uses default checkpoint. - device : torch.device or None, optional - Device to load model to, by default None. - If None, uses CUDA if available, else CPU. + device : torch.device + Device to load model to. 
    Returns
    -------
    DenseNet
        Loaded and initialized model in evaluation mode.
    """
-    if device is None or device == "auto":
-        from FastSurferCNN.utils.common import find_device
-        device = find_device(device)
-
+
     # Initialize model architecture (must match training)
     model = DenseNet(  # densenet201
         spatial_dims=2,
@@ -64,10 +55,10 @@ def load_model(device: torch.device | None = None) -> DenseNet:
     download_checkpoints(cc=True)
 
     cc_config = load_checkpoint_config_defaults(
-            "checkpoint",
-            filename=CC_YAML,
-        )
-    checkpoint_path = constants.FASTSURFER_ROOT / cc_config['localization']
+        "checkpoint",
+        filename=CC_YAML,
+    )
+    checkpoint_path = FASTSURFER_ROOT / cc_config['localization']
 
     # Load state dict
     if isinstance(checkpoint_path, str) or isinstance(checkpoint_path, Path):
@@ -141,7 +132,7 @@
-def run_inference(model: DenseNet,
+def run_inference(model: torch.nn.Module,
                   image_volume: np.ndarray,
                   third_ventricle_center: np.ndarray,
                   device: torch.device | None = None,
@@ -152,7 +143,7 @@
     Parameters
     ----------
-    model : DenseNet
+    model : torch.nn.Module
         Trained model for inference
     image_volume : np.ndarray
         Input volume as numpy array
diff --git a/CorpusCallosum/visualization/visualization.py b/CorpusCallosum/visualization/visualization.py
index 78e7c1a4..0d679373 100644
--- a/CorpusCallosum/visualization/visualization.py
+++ b/CorpusCallosum/visualization/visualization.py
@@ -95,8 +95,8 @@
     Notes
     -----
-    Saves a visualization of the anterior (red) and posterior (blue) commissure in three different views: 
-    1. the orig image (orig), 
+    Saves a visualization of the anterior (red) and posterior (blue) commissure in three different views:
+    1. the orig image (orig),
     2. fs-average standardized image space, and
     3. standardized image space
     as a single image named 'ac_pc_spaces.png' in `output_dir`.
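As an illustrative aside (not part of the patch itself): the `plot_standardized_space` refactor above relies on one non-obvious expression, `vol[(slice(None),) * i + (vol.shape[i] // 2,)]`, to pick the middle slice along axis `i`, with `i` supplied by `enumerate`. A minimal, self-contained sketch of that indexing trick — the array shape here is invented purely for illustration:

```python
import numpy as np

vol = np.zeros((4, 5, 6))  # stand-in for a 3D image volume

for i in range(3):
    # (slice(None),) * i expands to ':' for the first i axes, so appending the
    # middle index selects the central 2D slice orthogonal to axis i,
    # e.g. i=1 is equivalent to vol[:, vol.shape[1] // 2]
    mid = vol[(slice(None),) * i + (vol.shape[i] // 2,)]
    assert np.array_equal(mid, np.take(vol, vol.shape[i] // 2, axis=i))
```

Building the index tuple this way lets one loop replace three hand-written branches for the axial, sagittal, and coronal views, which is exactly what the hunk exploits.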
@@ -187,13 +187,10 @@ def plot_contours(
         ax[current_plot].imshow(transformed[transformed.shape[0] // 2], cmap="gray")
         # ax[0].imshow(cc_mask, cmap='autumn')
         ax[current_plot].set_title(title)
-        for i in range(len(split_contours)):
-            ax[current_plot].fill(split_contours[i][0, :], -split_contours[i][1, :], color="steelblue", alpha=0.25)
-            ax[current_plot].plot(
-                split_contours[i][0, :], -split_contours[i][1, :], color="mediumblue", linestyle="dotted", linewidth=0.7
-            )
-
-        ax[current_plot].plot(split_contours[0][0, :], -split_contours[0][1, :], color="mediumblue", linewidth=0.7)
+        for i, this_contour in enumerate(split_contours):
+            ax[current_plot].fill(this_contour[0, :], -this_contour[1, :], color="steelblue", alpha=0.25)
+            kwargs = {"color": "mediumblue", "linewidth": 0.7, "linestyle": "solid" if i == 0 else "dotted"}
+            ax[current_plot].plot(this_contour[0, :], -this_contour[1, :], **kwargs)
         ax[current_plot].scatter(ac_coords[1], ac_coords[0], color="red", marker="x")
         ax[current_plot].scatter(pc_coords[1], pc_coords[0], color="blue", marker="x")
         current_plot += 1
@@ -202,8 +199,8 @@
         ax[current_plot].imshow(transformed[transformed.shape[0] // 2], cmap="gray")
         # ax[2].imshow(cc_mask, cmap='autumn')
 
-        for i in range(len(levelpaths)):
-            ax[current_plot].plot(levelpaths[i][:, 0], -levelpaths[i][:, 1], color="brown", linewidth=0.8)
+        for this_path in levelpaths:
+            ax[current_plot].plot(this_path[:, 0], -this_path[:, 1], color="brown", linewidth=0.8)
         ax[current_plot].set_title("Midline & Levelpaths")
         ax[current_plot].plot(midline_equidistant[:, 0], -midline_equidistant[:, 1], color="red")
         ax[current_plot].plot(reference_contour[0, :], -reference_contour[1, :], color="red", linewidth=0.5)

From 3dbf31da3f7c9bc566a125ddd8c56ad6e4c80f7b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20K=C3=BCgler?=
Date: Fri, 14 Nov 2025 11:34:04 +0100
Subject: [PATCH 28/68] Fixing problems introduced by incomplete changes in
 review: resolves several issues from the review, like using
 concurrent.futures.*. Cleanup, optimizations, and formatting (e.g.
variable names) --- CorpusCallosum/cc_visualization.py | 4 +- CorpusCallosum/data/constants.py | 8 +- CorpusCallosum/data/fsaverage_cc_template.py | 22 +- CorpusCallosum/data/read_write.py | 142 +--- CorpusCallosum/fastsurfer_cc.py | 490 +++++++------ .../localization/localization_inference.py | 76 +- .../registration/mapping_helpers.py | 39 +- .../segmentation/segmentation_inference.py | 149 ++-- .../segmentation_postprocessing.py | 44 +- CorpusCallosum/shape/cc_endpoint_heuristic.py | 2 - CorpusCallosum/shape/cc_mesh.py | 142 ++-- CorpusCallosum/shape/cc_metrics.py | 19 +- CorpusCallosum/shape/cc_postprocessing.py | 656 +++++++++--------- CorpusCallosum/shape/cc_subsegment_contour.py | 33 +- CorpusCallosum/shape/cc_thickness.py | 63 +- .../transforms/localization_transforms.py | 101 +-- recon_surf/align_points.py | 12 +- 17 files changed, 929 insertions(+), 1073 deletions(-) diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index 3481c3cb..c1df96f8 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -8,7 +8,7 @@ from CorpusCallosum.data.constants import FSAVERAGE_DATA_PATH from CorpusCallosum.data.fsaverage_cc_template import load_fsaverage_cc_template from CorpusCallosum.data.read_write import load_fsaverage_data -from CorpusCallosum.shape.cc_mesh import CC_Mesh +from CorpusCallosum.shape.cc_mesh import CCMesh def make_parser() -> argparse.ArgumentParser: @@ -110,7 +110,7 @@ def main( output_dir = Path(output_dir) # Load data and create mesh - cc_mesh = CC_Mesh(num_slices=1) # Will be resized when loading data + cc_mesh = CCMesh(num_slices=1) # Will be resized when loading data _, _, vox2ras_tkr = load_fsaverage_data(FSAVERAGE_DATA_PATH) diff --git a/CorpusCallosum/data/constants.py b/CorpusCallosum/data/constants.py index c5b9e63d..321bd8da 100644 --- a/CorpusCallosum/data/constants.py +++ b/CorpusCallosum/data/constants.py @@ -13,14 +13,14 @@ # limitations under the License. 
-from pathlib import Path from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT ### Constants WEIGHTS_PATH = FASTSURFER_ROOT / "checkpoints" -FSAVERAGE_CENTROIDS_PATH = FASTSURFER_ROOT / "CorpusCallosum" / "fsaverage_centroids.json" -FSAVERAGE_DATA_PATH = FASTSURFER_ROOT / "CorpusCallosum" / "fsaverage_data.json" # Contains both affine and header +FSAVERAGE_CENTROIDS_PATH = FASTSURFER_ROOT / "CorpusCallosum" / "data" / "fsaverage_centroids.json" +# Contains both affine and header +FSAVERAGE_DATA_PATH = FASTSURFER_ROOT / "CorpusCallosum" / "data" / "fsaverage_data.json" FSAVERAGE_MIDDLE = 128 # Middle slice index in fsaverage space CC_LABEL = 192 # Label value for corpus callosum in segmentation FORNIX_LABEL = 250 # Label value for fornix in segmentation @@ -28,7 +28,7 @@ STANDARD_INPUT_PATHS = { - "t1": "mri/orig.mgz", + "conf_name": "mri/orig.mgz", "aseg_name": "mri/aparc.DKTatlas+aseg.deep.mgz", } diff --git a/CorpusCallosum/data/fsaverage_cc_template.py b/CorpusCallosum/data/fsaverage_cc_template.py index 497c52ed..e0106421 100644 --- a/CorpusCallosum/data/fsaverage_cc_template.py +++ b/CorpusCallosum/data/fsaverage_cc_template.py @@ -20,7 +20,7 @@ from scipy import ndimage from CorpusCallosum.data import constants -from CorpusCallosum.shape.cc_postprocessing import process_slice +from CorpusCallosum.shape.cc_postprocessing import recon_cc_surf_measure from FastSurferCNN.utils.brainvolstats import mask_in_array @@ -121,16 +121,16 @@ def load_fsaverage_cc_template() -> tuple[ cc_mask = cc_mask_smoothed.astype(int) * 192 (_, contour_with_thickness, anterior_endpoint_idx, - posterior_endpoint_idx) = process_slice(segmentation=cc_mask[None], - slice_idx=0, - ac_coords=AC, - pc_coords=PC, - affine=fsaverage_seg.affine, - num_thickness_points=100, - subdivisions=[1/6, 1/2, 2/3, 3/4], - subdivision_method="shape", - contour_smoothing=5, - vox_size=1) + posterior_endpoint_idx) = recon_cc_surf_measure(segmentation=cc_mask[None], + slice_idx=0, + ac_coords=AC, + pc_coords=PC, + affine=fsaverage_seg.affine, + num_thickness_points=100, + subdivisions=[1/6, 1/2, 2/3, 3/4], + subdivision_method="shape", + contour_smoothing=5, + vox_size=1) outside_contour = contour_with_thickness[0].T diff --git a/CorpusCallosum/data/read_write.py b/CorpusCallosum/data/read_write.py index 00442944..01b354d7 100644 --- a/CorpusCallosum/data/read_write.py +++ b/CorpusCallosum/data/read_write.py @@ -13,75 +13,44 @@ # limitations under the License. import json -import multiprocessing from pathlib import Path -from typing import overload +from typing import TypedDict import nibabel as nib import numpy as np +from numpy import typing as npt import FastSurferCNN.utils.logging as logging -logger = logging.get_logger(__name__) - - -def run_in_background(function: callable, debug: bool = False, *args, **kwargs) -> multiprocessing.Process | None: - """Run a function in the background using multiprocessing. - - Parameters - ---------- - function : callable - The function to execute. - debug : bool, optional - If True, run synchronously in current process, by default False. - *args - Positional arguments to pass to the function. - **kwargs - Keyword arguments to pass to the function. - Returns - ------- - multiprocessing.Process or None - Process object if running in background, None if in debug mode. 
- """ - if debug: - function(*args, **kwargs) - process = None - else: - process = multiprocessing.Process(target=function, args=args, kwargs=kwargs) - process.start() - return process +class FSAverageHeader(TypedDict): + dims: npt.NDArray[int] + delta: npt.NDArray[float] + Mdc: npt.NDArray[float] + Pxyz_c: npt.NDArray[float] +logger = logging.get_logger(__name__) -@overload -def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: None = None) -> dict[int, np.ndarray]: - ... - -@overload -def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: list[int]) -> tuple[dict[int, np.ndarray], list[int]]: - ... -def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: list[int] | None = None): +def get_centroids_from_nib(seg_img: nib.analyze.SpatialImage, label_ids: list[int] | None = None) \ + -> dict[int, np.ndarray | None]: """Get centroids of segmentation labels in RAS coordinates. Parameters ---------- - seg_img : nibabel.Nifti1Image + seg_img : nibabel.analyze.SpatialImage Input segmentation image. label_ids : list[int], optional List of label IDs to extract centroids for. If None, extracts all non-zero labels. Returns ------- - dict[int, np.ndarray] - If label_ids is None, returns a dict mapping label IDs to their centroids (x,y,z) in RAS coordinates. - If label_ids is provided, returns a tuple containing: - - dict[int, np.ndarray]: Mapping of found label IDs to their centroids. - - list[int]: List of label IDs that were not found in the image. + dict[int, np.ndarray | None] + A dict mapping label IDs to their centroids (x,y,z) in RAS coordinates, None if label did not exist. """ # Get segmentation data and affine - seg_data = seg_img.get_fdata() - vox2ras = seg_img.affine + seg_data: npt.NDArray[np.integer] = np.asarray(seg_img.dataobj) + vox2ras: npt.NDArray[float] = seg_img.affine # Get unique labels if label_ids is None: @@ -90,61 +59,23 @@ def get_centroids_from_nib(seg_img: nib.Nifti1Image, label_ids: list[int] | None else: labels = label_ids - centroids = {} - ids_not_found = [] - for label in labels: - # Get voxel indices for this label - vox_coords = np.array(np.where(seg_data == label)) - if vox_coords.size == 0: - ids_not_found.append(label) - continue + def _calc_ras_centroid(mask_vox: npt.NDArray[np.integer]) -> npt.NDArray[float]: # Calculate centroid in voxel space - vox_centroid = np.mean(vox_coords, axis=1) - + vox_centroid = np.mean(mask_vox, axis=1, dtype=float) + # Convert to homogeneous coordinates vox_centroid = np.append(vox_centroid, 1) - - # Transform to RAS coordinates - ras_centroid = vox2ras @ vox_centroid - - # Store without homogeneous coordinate - centroids[int(label)] = ras_centroid[:3] - - if label_ids is not None: - return centroids, ids_not_found - else: - return centroids - + # Transform to RAS coordinates and return without homogeneous coordinate + return (vox2ras @ vox_centroid)[:3] -def save_nifti_background( - io_processes: list, - data: np.ndarray, - affine: np.ndarray, - header: nib.Nifti1Header, - filepath: str | Path -) -> None: - """Save a NIfTI image in a background process. - - Creates a MGHImage from the provided data and metadata, then saves it to disk - using a background process to avoid blocking the main execution. - - Parameters - ---------- - io_processes : list - List to store background process handles. - data : np.ndarray - Image data array. - affine : np.ndarray - 4x4 affine transformation matrix. - header : nib.Nifti1Header - NIfTI header object containing metadata. 
- filepath : str or Path - Path where the image should be saved. - """ - logger.info(f"Saving NIfTI image to {filepath}") - io_processes.append(run_in_background(nib.save, False, - nib.MGHImage(data, affine, header), filepath)) + centroids = {} + for label in labels: + # Get voxel indices for this label + vox_coords = np.array(np.where(seg_data == label)) + centroids[int(label)] = None if vox_coords.size == 0 else _calc_ras_centroid(vox_coords) + + return centroids def convert_numpy_to_json_serializable(obj: object) -> object: @@ -173,7 +104,7 @@ def convert_numpy_to_json_serializable(obj: object) -> object: return obj -def load_fsaverage_centroids(centroids_path: str | Path) -> dict[int, np.ndarray]: +def load_fsaverage_centroids(centroids_path: str | Path) -> dict[int, npt.NDArray[float]]: """Load fsaverage centroids from static JSON file. Parameters @@ -198,7 +129,7 @@ def load_fsaverage_centroids(centroids_path: str | Path) -> dict[int, np.ndarray return {int(label): np.array(centroid) for label, centroid in centroids_data.items()} -def load_fsaverage_affine(affine_path: str | Path) -> np.ndarray: +def load_fsaverage_affine(affine_path: str | Path) -> npt.NDArray[float]: """Load fsaverage affine matrix from static text file. Parameters @@ -216,7 +147,7 @@ def load_fsaverage_affine(affine_path: str | Path) -> np.ndarray: if not affine_path.exists(): raise FileNotFoundError(f"Fsaverage affine file not found: {affine_path}") - affine_matrix = np.loadtxt(affine_path) + affine_matrix = np.loadtxt(affine_path).astype(float) if affine_matrix.shape != (4, 4): raise ValueError(f"Expected 4x4 affine matrix, got shape {affine_matrix.shape}") @@ -224,7 +155,7 @@ def load_fsaverage_affine(affine_path: str | Path) -> np.ndarray: return affine_matrix -def load_fsaverage_data(data_path: str | Path) -> tuple[np.ndarray, dict, np.ndarray]: +def load_fsaverage_data(data_path: str | Path) -> tuple[npt.NDArray[float], FSAverageHeader, npt.NDArray[float]]: """Load fsaverage affine matrix and header fields from static JSON file. Parameters @@ -257,9 +188,7 @@ def load_fsaverage_data(data_path: str | Path) -> tuple[np.ndarray, dict, np.nda If the file is not valid JSON. ValueError If required fields are missing. 
- """ - data_path = Path(data_path) if not data_path.exists(): raise FileNotFoundError(f"Fsaverage data file not found: {data_path}") @@ -281,9 +210,12 @@ def load_fsaverage_data(data_path: str | Path) -> tuple[np.ndarray, dict, np.nda # Convert lists back to numpy arrays affine_matrix = np.array(data["affine"]) vox2ras_tkr = np.array(data["vox2ras_tkr"]) - header_data = data["header"].copy() - header_data["Mdc"] = np.array(header_data["Mdc"]) - header_data["Pxyz_c"] = np.array(header_data["Pxyz_c"]) + header_data = FSAverageHeader( + dims=data["header"]["dims"], + delta=data["header"]["delta"], + Mdc=np.array(data["header"]["Mdc"]), + Pxyz_c=np.array(data["header"]["Pxyz_c"]), + ) # Validate affine matrix shape if affine_matrix.shape != (4, 4): diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index 2e98ac64..de458eea 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -16,16 +16,14 @@ import argparse import json from pathlib import Path -from typing import Literal +from typing import Literal, cast -# import warnings warnings.filterwarnings("ignore", message="TypedStorage is deprecated") import nibabel as nib import numpy as np import torch from monai.networks.nets import DenseNet from numpy import typing as npt -import FastSurferCNN.utils.logging as logging from CorpusCallosum.data.constants import ( CC_LABEL, FSAVERAGE_CENTROIDS_PATH, @@ -35,73 +33,46 @@ STANDARD_OUTPUT_PATHS, ) from CorpusCallosum.data.read_write import ( + FSAverageHeader, convert_numpy_to_json_serializable, get_centroids_from_nib, load_fsaverage_centroids, load_fsaverage_data, - run_in_background, - save_nifti_background, ) from CorpusCallosum.localization import localization_inference from CorpusCallosum.registration.mapping_helpers import ( apply_transform_to_pt, apply_transform_to_volume, - get_mapping_to_standard_space, + calc_mapping_to_standard_space, interpolate_midplane, map_softlabels_to_orig, ) from CorpusCallosum.segmentation import segmentation_inference, segmentation_postprocessing from CorpusCallosum.shape.cc_postprocessing import ( + SubdivisionMethod, check_area_changes, make_subdivision_mask, - process_slices, + recon_cc_surf_measures_multi, ) from FastSurferCNN.data_loader.conform import is_conform -from FastSurferCNN.utils.common import find_device +from FastSurferCNN.utils import logging +from FastSurferCNN.utils.common import SubjectDirectory, find_device +from FastSurferCNN.utils.common import thread_executor as executor from recon_surf import lta from recon_surf.align_points import find_rigid logger = logging.get_logger(__name__) SliceSelection = Literal["middle", "all"] | int -SubdivisionMethod = Literal["shape", "vertical", "angular", "eigenvector"] def make_parser() -> argparse.ArgumentParser: """Create the argument parse object for the pipeline.""" + from FastSurferCNN.utils.parser_defaults import add_arguments + parser = argparse.ArgumentParser() # Specify subject directory + subject ID, OR specify individual MRI and segmentation files + output paths - mgroup = parser.add_mutually_exclusive_group() - mgroup.add_argument( - "--sd", - type=Path, - help="Root directory in which the case directory is located. " - "Must be used together with --sid.", - ) - parser.add_argument( - "--sid", - type=str, - help="Name of the case directory. Must be used together with --sd.", - ) - mgroup.add_argument( - "--t1", - type=Path, - help=f"Input MRI file path. Must be used together with --aseg_name. 
\ - (default: subject_dir/{STANDARD_INPUT_PATHS['t1']})", - ) - parser.add_argument( - "--aseg_name", - type=Path, - help=f"Input segmentation file path. Must be used together with --t1. \ - (default: subject_dir/{STANDARD_INPUT_PATHS['aseg_name']})", - ) - parser.add_argument( - "--device", - type=str, - default="auto", - help="Select device to run inference on: cpu, or cuda (= Nvidia gpu) or specify a certain gpu (e.g. cuda:1), \ - Default: auto", - ) + add_arguments(parser, ["sd", "sid", "conformed_name", "aseg_name", "device"]) parser.add_argument( "--num_thickness_points", @@ -112,18 +83,19 @@ def make_parser() -> argparse.ArgumentParser: parser.add_argument( "--subdivisions", type=float, - nargs="+", + metavar="FRAC", + nargs=4, default=[1/6, 1/2, 2/3, 3/4], - help="List of subdivision fractions for the corpus callosum subsegmentation.", + help="List of FOUR subdivision fractions for the corpus callosum subsegmentation.", ) parser.add_argument( "--subdivision_method", default="shape", help="Method for contour subdivision. \ - Options: shape (Intercallosal subdivision perpendicular to intercallosal line), vertical \ - (orthogonal to the most anterior and posterior points in the AC/PC standardized CC contour), \ - angular (subdivision based on equally spaced angles, as proposed by Hampel and colleagues), \ - eigenvector (primary direction, same as FreeSurfers mri_cc)", + Options: shape (Intercallosal subdivision perpendicular to intercallosal line), vertical \ + (orthogonal to the most anterior and posterior points in the AC/PC standardized CC contour), \ + angular (subdivision based on equally spaced angles, as proposed by Hampel and colleagues), \ + eigenvector (primary direction, same as FreeSurfers mri_cc)", choices=["shape", "vertical", "angular", "eigenvector"], ) parser.add_argument( @@ -145,16 +117,19 @@ def _slice_selection(a: str) -> SliceSelection: (default: 'all')", ) parser.add_argument( + "-v", "--verbose", - action="store_true", - help="Enable verbose (shows output paths)", - default=False, + action="count", + default=0, + help="Enable verbose (pass twice for debug-output).", ) ######## OUTPUT PATHS ######### # 4. 
Options for advanced, technical parameters - advanced = parser.add_argument_group(title="Advanced options", - description="Custom output paths, useful if no standard case directory is used.") + advanced = parser.add_argument_group( + title="Advanced options", + description="Custom output paths, useful if no standard case directory is used.", + ) advanced.add_argument("--qc_output_dir", type=Path, required=False, @@ -274,8 +249,6 @@ def _slice_selection(a: str) -> SliceSelection: default=None, ) ############ END OF OUTPUT PATHS ############ - - return parser @@ -287,11 +260,11 @@ def options_parse() -> argparse.Namespace: # Reconstruct subject_dir from sd and sid (but sd might be stored as out_dir by parser_defaults) sd_value = getattr(args, 'sd', getattr(args, 'out_dir', None)) if sd_value and hasattr(args, 'sid') and args.sid: - args.subject_dir = str(Path(sd_value) / args.sid) + args.subject_dir = Path(sd_value) / args.sid else: args.subject_dir = None - # Validation logic: must use either directory approach (--sd + --sid) OR file approach (--t1 + --aseg_name) + # Validation logic: must use either directory approach (--sd + --sid) OR file approach (--conf_name + --aseg_name) if sd_value: # Using directory approach - make sure sid was also provided if not (hasattr(args, 'sid') and args.sid): @@ -300,28 +273,28 @@ def options_parse() -> argparse.Namespace: # If sid is provided without sd, that's an error if not sd_value: parser.error("When using --sid, you must also provide --sd.") - elif hasattr(args, 't1') and args.t1: + elif hasattr(args, 'conf_name') and args.conf_name: # Using file approach - make sure aseg_name was also provided if not (hasattr(args, 'aseg_name') and args.aseg_name): - parser.error("When using --t1, you must also provide --aseg_name.") + parser.error("When using --conf_name, you must also provide --aseg_name.") elif hasattr(args, 'aseg_name') and args.aseg_name: - # If aseg_name is provided without t1, that's an error - if not (hasattr(args, 't1') and args.t1): - parser.error("When using --aseg_name, you must also provide --t1.") + # If aseg_name is provided without conf_name, that's an error + if not (hasattr(args, 'conf_name') and args.conf_name): + parser.error("When using --aseg_name, you must also provide --conf_name.") else: - parser.error("You must specify either --sd and --sid OR both --t1 and --aseg_name.") + parser.error("You must specify either --sd and --sid OR both --conf_name and --aseg_name.") # If subject_dir is provided, set default paths for missing arguments if args.subject_dir: - subject_dir_path = Path(args.subject_dir) + subject_dir_path = args.subject_dir # Create standard FreeSurfer subdirectories (subject_dir_path / "mri").mkdir(parents=True, exist_ok=True) (subject_dir_path / "stats").mkdir(parents=True, exist_ok=True) (subject_dir_path / "transforms").mkdir(parents=True, exist_ok=True) - if not args.t1: - args.t1 = str(subject_dir_path / STANDARD_INPUT_PATHS["t1"]) + if not args.conf_name: + args.conf_name = str(subject_dir_path / STANDARD_INPUT_PATHS["conf_name"]) if not args.aseg_name: args.aseg_name = str(subject_dir_path / STANDARD_INPUT_PATHS["aseg_name"]) @@ -343,8 +316,8 @@ def options_parse() -> argparse.Namespace: return args -def centroid_registration(aseg_nib: nib.Nifti1Image) -> tuple[ - npt.NDArray[float], npt.NDArray[float], npt.NDArray[float], nib.Nifti1Header, npt.NDArray[float] +def centroid_registration(aseg_nib: nib.analyze.SpatialImage) -> tuple[ + npt.NDArray[float], npt.NDArray[float], npt.NDArray[float], 
FSAverageHeader, npt.NDArray[float] ]: """Perform centroid-based registration between subject and fsaverage space. @@ -353,7 +326,7 @@ def centroid_registration(aseg_nib: nib.Nifti1Image) -> tuple[ Parameters ---------- - aseg_nib : nibabel.Nifti1Image + aseg_nib : nibabel.analyze.SpatialImage Subject's segmentation image. Returns @@ -364,7 +337,7 @@ def centroid_registration(aseg_nib: nib.Nifti1Image) -> tuple[ Transformation matrix from original to fsaverage RAS space. fsaverage_hires_affine : np.ndarray High-resolution fsaverage affine matrix. - fsaverage_header : nibabel.Nifti1Header + fsaverage_header : FSAverageHeader FSAverage header fields for LTA writing. vox2ras_tkr : np.ndarray Voxel to RAS tkr-space transformation matrix. @@ -381,35 +354,34 @@ def centroid_registration(aseg_nib: nib.Nifti1Image) -> tuple[ centroids_dst = load_fsaverage_centroids(FSAVERAGE_CENTROIDS_PATH) fsaverage_affine, fsaverage_header, vox2ras_tkr = load_fsaverage_data(FSAVERAGE_DATA_PATH) - centroids_mov, ids_not_found = get_centroids_from_nib(aseg_nib, label_ids=list(centroids_dst.keys())) + centroids_mov = get_centroids_from_nib(aseg_nib, label_ids=list(centroids_dst.keys())) - # delete not found labels from centroids_mov - for id in ids_not_found: - del centroids_dst[id] + # get the set of joint labels + joint_centroid_labels = [lbl for lbl, v in centroids_mov.items() if v is not None] - centroids_mov = np.array(list(centroids_mov.values())).T - centroids_dst = np.array(list(centroids_dst.values())).T + centroids_mov = np.array([centroids_mov[lbl] for lbl in joint_centroid_labels]).T + centroids_dst = np.array([centroids_dst[lbl] for lbl in joint_centroid_labels]).T - orig_fsaverage_ras2ras = find_rigid(p_mov=centroids_mov.T, p_dst=centroids_dst.T) + orig_fsaverage_ras2ras: npt.NDArray[float] = find_rigid(p_mov=centroids_mov.T, p_dst=centroids_dst.T) # make affine that increases resolution to orig resolution - resolution_trans = np.diagflat(list(aseg_nib.header.get_zooms()[:3]) + [1]) + resolution_trans: npt.NDArray[float] = np.diagflat(list(aseg_nib.header.get_zooms()[:3]) + [1]).astype(float) - orig_fsaverage_vox2vox = ( + orig_fsaverage_vox2vox: npt.NDArray[float] = ( np.linalg.inv(resolution_trans @ fsaverage_affine) @ orig_fsaverage_ras2ras @ aseg_nib.affine ) - fsaverage_hires_affine = resolution_trans @ fsaverage_affine + fsaverage_hires_affine: npt.NDArray[float] = resolution_trans @ fsaverage_affine logger.info("Centroid registration successful!") return orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, fsaverage_hires_affine, fsaverage_header, vox2ras_tkr def localize_ac_pc( midslices: np.ndarray, - aseg_nib: "nib.Nifti1Image", + aseg_nib: nib.analyze.SpatialImage, orig_fsaverage_vox2vox: npt.NDArray[float], model_localization: DenseNet, - slices_to_analyze: int -) -> tuple[np.ndarray, np.ndarray]: + num_slices_to_analyze: int +) -> tuple[npt.NDArray[float], npt.NDArray[float]]: """Localize anterior and posterior commissure points in the brain. Uses a trained model to detect AC and PC points in mid-sagittal slices, @@ -419,13 +391,13 @@ def localize_ac_pc( ---------- midslices : np.ndarray Array of mid-sagittal slices. - aseg_nib : nibabel.Nifti1Image - Subject's segmentation image. + aseg_nib : nibabel.analyze.SpatialImage + Subject's segmentation image in native subject space. orig_fsaverage_vox2vox : np.ndarray - Transformation matrix to fsaverage space. + Transformation matrix from subject/native space to fsaverage space (in lia). 
    model_localization : DenseNet
        Trained model for AC-PC detection.
-    slices_to_analyze : int
+    num_slices_to_analyze : int
        Number of slices to process.

    Returns
@@ -442,12 +414,10 @@
    third_ventricle_center_vox = apply_transform_to_pt(third_ventricle_center, orig_fsaverage_vox2vox, inv=False)

    # get 5 mm of slices output with 3 slices per inference
-    midslices_middle = midslices.shape[0] // 2
-    middle_slices_localization = midslices[
-        midslices_middle - slices_to_analyze // 2 - 1 : midslices_middle + slices_to_analyze // 2 + 2
-    ]
+    midslices_start = midslices.shape[0] // 2 - num_slices_to_analyze // 2 - 1
+    middle_slices_localization = midslices[midslices_start:midslices_start + num_slices_to_analyze + 3]
    ac_coords, pc_coords = localization_inference.run_inference_on_slice(
-        model_localization, middle_slices_localization, third_ventricle_center_vox[1:]
+        model_localization, middle_slices_localization, third_ventricle_center_vox[1:],
    )
    return ac_coords, pc_coords
@@ -463,9 +433,8 @@ def segment_cc(
 ) -> tuple[npt.NDArray[bool], npt.NDArray[float]]:
    """Segment the corpus callosum using a trained model.

-    Performs corpus callosum segmentation on mid-sagittal slices using a trained model,
-    with AC-PC points as anatomical references. Includes post-processing to clean
-    the segmentation.
+    Performs corpus callosum segmentation on mid-sagittal slices using a trained model, with AC-PC points as anatomical
+    references. Includes post-processing to clean the segmentation.

    Parameters
    ----------
@@ -476,43 +445,42 @@
    ac_coords : np.ndarray
        Anterior commissure coordinates.
    pc_coords : np.ndarray
        Posterior commissure coordinates.
    aseg_nib : nibabel.Nifti1Image
        Subject's segmentation image.
    model_segmentation : torch.nn.Module
        Trained model for CC segmentation.
    slices_to_analyze : int
        Number of slices to process.

    Returns
    -------
-    segmentation : np.ndarray
-        Binary segmentation of the corpus callosum.
-    outputs_soft : np.ndarray
-        Soft segmentation probabilities.
+    cc_seg_labels : np.ndarray
+        Binary segmentation of the corpus callosum.
+    cc_softlabels : np.ndarray
+        Soft segmentation probabilities.
""" # get 5 mm of slices output with 9 slices per inference midslices_start = midslices.shape[0] // 2 - slices_to_analyze // 2 - 4 middle_slices_slab = midslices[midslices_start:midslices_start + slices_to_analyze + 9] - pre_clean_segmentation, inputs, outputs_avg, outputs_soft = segmentation_inference.run_inference_on_slice( + pre_clean_segmentation, inputs, cc_softlabels = segmentation_inference.run_inference_on_slice( model_segmentation, middle_slices_slab, - AC_center=ac_coords, - PC_center=pc_coords, + ac_center=ac_coords, + pc_center=pc_coords, voxel_size=aseg_nib.header.get_zooms()[0], ) - segmentation, cc_volume_mask = segmentation_postprocessing.clean_cc_segmentation(pre_clean_segmentation) + cc_seg_labels, cc_volume_mask = segmentation_postprocessing.clean_cc_segmentation(pre_clean_segmentation) # print a warning if the cc_volume_mask touches the edge of the segmentation - if ( - np.any(cc_volume_mask[:, [0, -1]]) - or np.any(cc_volume_mask[:, :, [0, -1]]) - ): - logger.warning("CC volume mask touches the edge of the segmentation field-of-view, CC might be truncated") + if np.any(cc_volume_mask[:, [0, -1]]) or np.any(cc_volume_mask[:, :, [0, -1]]): + logger.warning("CC volume mask touches the edge of the cc_seg_labels field-of-view, CC might be truncated") # get voxels that were removed during cleaning - outputs_soft[pre_clean_segmentation != segmentation, 1] = 0 + cleaned_mask = pre_clean_segmentation != cc_seg_labels + cc_softlabels[cleaned_mask, 1] = 0 + cc_softlabels[cleaned_mask, :] /= np.sum(cc_softlabels[cleaned_mask, :], axis=-1, keepdims=True) + 1e-6 - return segmentation, outputs_soft + return cc_seg_labels, cc_softlabels def main( @@ -520,8 +488,8 @@ def main( aseg_name: str | Path, subject_dir: str | Path, slice_selection: SliceSelection = "middle", + #TODO: qc_output_dir is currently unused ?! qc_output_dir: str | Path | None = None, - verbose: bool = False, num_thickness_points: int = 100, subdivisions: list[float] | None = None, subdivision_method: SubdivisionMethod = "shape", @@ -562,8 +530,6 @@ def main( Which slices to process. qc_output_dir : str or Path, optional Directory for quality control outputs, None deactivates qc snapshots. - verbose : bool, default=False - Flag for verbose output. num_thickness_points : int, default=100 Number of points for thickness estimation. subdivisions : list[float], optional @@ -628,52 +594,66 @@ def main( 5. Performs enhanced post-processing analysis. 6. Saves results and visualizations. 
""" + import sys if subdivisions is None: subdivisions = [1 / 6, 1 / 2, 2 / 3, 3 / 4] - # Set up logging if verbose mode is enabled - if verbose: - logging.setup_logging(None) # Log to stdout only - + subject_dir = Path(subject_dir) + logger.info("Starting corpus callosum analysis pipeline") logger.info(f"Input MRI: {conf_name}") logger.info(f"Input segmentation: {aseg_name}") logger.info(f"Output directory: {subject_dir}") # Convert all paths to Path objects - t1_path = Path(conf_name) - aseg_path = Path(aseg_name) - if subject_dir: - subject_dir = Path(subject_dir) - if save_template_dir: - save_template_dir = Path(save_template_dir) + sd = SubjectDirectory( + subject_dir.parent, + id=subject_dir.name, + conf_name=conf_name, + aseg_name=aseg_name, + save_template_dir=save_template_dir, + upright_volume=upright_volume_path, + cc_segmentation=segmentation_path, + cc_postproc_results=postproc_results_path, + cc_markers=cc_markers_path, + upright_lta=upright_lta_path, + cc_orient_volume_lta=orient_volume_lta_path, + cc_surf=surf_file_path, + cc_overlay=overlay_file_path, + cc_html=cc_html_path, + cc_mesh=vtk_file_path, + cc_orig_segfile=orig_space_segmentation_path, + cc_qc_images=qc_image_path, + cc_thickness_image=thickness_image_path, + cc_softlabels_cc=softlabels_cc_path, + cc_softlabels_fn=softlabels_fn_path, + cc_softlabels_background=softlabels_background_path, + ) # Validate subdivision fractions if any(i < 0 or i > 1 for i in subdivisions): - logger.error(f"Error: Subdivision fractions must be between 0 and 1, but got: {subdivisions}") - import sys + logger.error(f"Subdivision fractions must be between 0 and 1, but got: {subdivisions}") sys.exit(1) #### setup variables - IO_processes = [] + io_futures = [] - orig = nib.load(t1_path) + orig = nib.load(sd.conf_name) - # 5 mm around the midplane - slices_to_analyze = int(np.ceil(5 / orig.header.get_zooms()[0])) - if slices_to_analyze % 2 == 0: - slices_to_analyze += 1 + # 5 mm around the midplane (making sure to get rl by as_closest_canonical) + slices_to_analyze = int(np.ceil(5 / nib.as_closest_canonical(orig).header.get_zooms()[0])) // 2 * 2 + 1 - if verbose: - logger.info( - f"Segmenting {slices_to_analyze} slices (5 mm width at {orig.header.get_zooms()[0]} mm resolution, " - "center around the mid-sagittal plane)" - ) + logger.info( + f"Segmenting {slices_to_analyze} slices (5 mm width at {orig.header.get_zooms()[0]} mm resolution, " + "center around the mid-sagittal plane)" + ) if not is_conform(orig, vox_size='min', img_size=None): - logger.error("Error: MRI is not conformed, please run conform.py or mri_convert to conform the image.") - raise ValueError("MRI is not conformed, please run conform.py or mri_convert to conform the image.") + if is_conform(orig, vox_size=None, img_size=None): + logger.warning("fastsurfer_cc currently requires isotropic images.") + logger.error("MRI is not conformed, please run conform.py or mri_convert to conform the image.") + sys.exit(1) # load models device = find_device(device) @@ -683,73 +663,61 @@ def main( model_localization = localization_inference.load_model(device=device) model_segmentation = segmentation_inference.load_model(device=device) - aseg_nib = nib.load(aseg_path) + aseg_nib = cast(nib.analyze.SpatialImage, nib.load(sd.filename_by_attribute("aseg_name"))) logger.info("Performing centroid registration to fsaverage space") - (orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, - fsaverage_hires_affine, fsaverage_header, fsaverage_vox2ras_tkr) = centroid_registration( + 
orig2fsavg_vox2vox, orig2fsavg_ras2ras, fsavg_affine, fsavg_header, fsavg_vox2ras_tkr = centroid_registration( aseg_nib ) - - if verbose: - logger.info("Interpolating midplane") - logger.info("Interpolating midplane slices") # this is a fast interpolation to not block the main thread - midslices = interpolate_midplane(orig, orig_fsaverage_vox2vox, slices_to_analyze) + midslices = interpolate_midplane(orig, orig2fsavg_vox2vox, slices_to_analyze) # start saving upright volume - IO_processes.append( - run_in_background( - apply_transform_to_volume, - False, - orig.get_fdata(), - orig_fsaverage_vox2vox, - fsaverage_hires_affine, - None, - upright_volume_path, - output_size=np.array([256, 256, 256]), + if sd.has_attribute("upright_volume"): + io_futures.append( + executor().submit( + apply_transform_to_volume, + orig, + orig2fsavg_vox2vox, + fsavg_affine, + output_path=sd.filename_by_attribute("upright_volume"), + output_size=np.array([256, 256, 256]), + ) ) - ) #### do localization and segmentation inference logger.info("Starting AC/PC localization") ac_coords, pc_coords = localize_ac_pc( - midslices, aseg_nib, orig_fsaverage_vox2vox, model_localization, slices_to_analyze + midslices, aseg_nib, orig2fsavg_vox2vox, model_localization, slices_to_analyze, ) logger.info("Starting corpus callosum segmentation") - segmentation, outputs_soft = segment_cc( - midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, slices_to_analyze + cc_fn_seg_labels, cc_fn_softlabels = segment_cc( + midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, slices_to_analyze, ) # calculate affine for segmentation volume orig_to_seg = np.eye(4) orig_to_seg[0, 3] = -FSAVERAGE_MIDDLE + slices_to_analyze // 2 - seg_affine = fsaverage_hires_affine - seg_affine = seg_affine @ np.linalg.inv(orig_to_seg) + seg_affine = fsavg_affine @ np.linalg.inv(orig_to_seg) # save softlabels - if softlabels_background_path is not None: - if verbose: - logger.info(f"Saving background softlabels to {softlabels_background_path}") - save_nifti_background(IO_processes, outputs_soft[..., 0], seg_affine, orig.header, softlabels_background_path) - if softlabels_cc_path is not None: - if verbose: - logger.info(f"Saving cc softlabels to {softlabels_cc_path}") - save_nifti_background(IO_processes, outputs_soft[..., 1], seg_affine, orig.header, softlabels_cc_path) - if softlabels_fn_path is not None: - if verbose: - logger.info(f"Saving fornix softlabels to {softlabels_fn_path}") - save_nifti_background(IO_processes, outputs_soft[..., 2], seg_affine, orig.header, softlabels_fn_path) - + for i, (attr, name) in enumerate((("background",) * 2, ("cc", "Corpus Callosum"), ("fn", "Fornix"))): + if sd.has_attribute(f"cc_softlabels_{attr}"): + logger.info(f"Saving {name} softlabels to {sd.filename_by_attribute(f'cc_softlabels_{attr}')}") + io_futures.append(executor().submit( + nib.save, + nib.MGHImage(cc_fn_softlabels[..., i], seg_affine, orig.header), + sd.filename_by_attribute(f"cc_softlabels_{attr}"), + )) # Create a temporary segmentation image with proper affine for enhanced postprocessing # Process slices based on selection mode logger.info(f"Processing slices with selection mode: {slice_selection}") - slice_results, slice_io_processes = process_slices( - segmentation=segmentation, + slice_results, slice_io_futures = recon_cc_surf_measures_multi( + segmentation=cc_fn_seg_labels, slice_selection=slice_selection, - temp_seg_affine=fsaverage_hires_affine, + temp_seg_affine=fsavg_affine, midslices=midslices, ac_coords=ac_coords, 
        pc_coords=pc_coords,
@@ -757,55 +725,48 @@
        subdivisions=subdivisions,
        subdivision_method=subdivision_method,
        contour_smoothing=contour_smoothing,
-        qc_image_path=qc_image_path,
-        one_debug_image=True,
-        surf_file_path=surf_file_path,
-        overlay_file_path=overlay_file_path,
-        cc_html_path=cc_html_path,
-        vtk_file_path=vtk_file_path,
-        thickness_image_path=thickness_image_path,
        vox_size=orig.header.get_zooms(),
-        vox2ras_tkr=fsaverage_vox2ras_tkr,
-        verbose=verbose,
-        save_template=save_template_dir,
+        vox2ras_tkr=fsavg_vox2ras_tkr,
+        subject_dir=sd,
    )
-    IO_processes.extend(slice_io_processes)
+    io_futures.extend(slice_io_futures)

    outer_contours = [slice_result['split_contours'][0] for slice_result in slice_results]
-    if len(outer_contours) > 1 and not check_area_changes(outer_contours, verbose=True):
-        logger.warning("Large area changes detected between consecutive slices, "
-                       "this is likely due to a segmentation error.")
+    if len(outer_contours) > 1 and not check_area_changes(outer_contours):
+        logger.warning(
+            "Large area changes detected between consecutive slices, this is likely due to a segmentation error."
+        )

    # Get middle slice result for backward compatibility
    middle_slice_result = slice_results[len(slice_results) // 2]

    if len(middle_slice_result['split_contours']) <= 5:
-        subdivision_mask = make_subdivision_mask(segmentation.shape[1:],
-                                                 middle_slice_result['split_contours'],
-                                                 orig.header.get_zooms())
+        subdivision_mask = make_subdivision_mask(
+            cc_fn_seg_labels.shape[1:],
+            middle_slice_result['split_contours'],
+            orig.header.get_zooms(),
+        )
    else:
        logger.warning("Too many subsegments for lookup table, skipping subdivision of output segmentation.")
        subdivision_mask = None

    # map soft labels to original space (in parallel because this takes a while)
-    IO_processes.append(
-        run_in_background(
-            map_softlabels_to_orig,
-            debug=False,
-            outputs_soft=outputs_soft,
-            orig_fsaverage_vox2vox=orig_fsaverage_vox2vox,
-            orig=orig,
-            slices_to_analyze=slices_to_analyze,
-            orig_space_segmentation_path=orig_space_segmentation_path,
-            fsaverage_middle=FSAVERAGE_MIDDLE,
-            subdivision_mask=subdivision_mask,
-        )
-    )
-
-    save_nifti_background(IO_processes, segmentation, seg_affine, orig.header, segmentation_path)
-
+    io_futures.append(executor().submit(
+        map_softlabels_to_orig,
+        outputs_soft=cc_fn_softlabels,
+        orig_fsaverage_vox2vox=orig2fsavg_vox2vox,
+        orig=orig,
+        slices_to_analyze=slices_to_analyze,
+        orig_space_segmentation_path=orig_space_segmentation_path,
+        fsaverage_middle=FSAVERAGE_MIDDLE,
+        subdivision_mask=subdivision_mask,
+    ))
+    io_futures.append(executor().submit(
+        nib.save,
+        nib.MGHImage(cc_fn_seg_labels, seg_affine, orig.header),
+        sd.filename_by_attribute("cc_segmentation"),
+    ))

    METRICS = [
        "areas",
@@ -820,38 +781,30 @@
    ]

    # Record key metrics for middle slice
-    output_metrics_middle_slice = {
-        metric: middle_slice_result[metric] for metric in METRICS
-    }
+    output_metrics_middle_slice = {metric: middle_slice_result[metric] for metric in METRICS}

    # Create enhanced output dictionary with all slice results
    per_slice_output_dict = {
        "slices": [
-            convert_numpy_to_json_serializable(
-                {
-                    metric: result[metric] for metric in METRICS
-                }
-            )
+            convert_numpy_to_json_serializable({metric: result[metric] for metric in METRICS})
            for result in slice_results
        ],
    }

    ########## Save outputs ##########
-
    additional_metrics = {}
    if len(outer_contours) > 1:
        cc_volume_voxel = segmentation_postprocessing.get_cc_volume_voxel(
            desired_width_mm=5,
-            cc_mask=segmentation ==
            CC_LABEL,
+            cc_mask=cc_fn_seg_labels == CC_LABEL,
            voxel_size=orig.header.get_zooms()
        )
        cc_volume_contour = segmentation_postprocessing.get_cc_volume_contour(
            cc_contours=outer_contours,
            voxel_size=orig.header.get_zooms()
        )
-        if verbose:
-            logger.info(f"CC volume voxel: {cc_volume_voxel}")
-            logger.info(f"CC volume contour: {cc_volume_contour}")
+        logger.info(f"CC volume voxel: {cc_volume_voxel}")
+        logger.info(f"CC volume contour: {cc_volume_contour}")

        additional_metrics["cc_5mm_volume"] = cc_volume_voxel
        additional_metrics["cc_5mm_volume_pv_corrected"] = cc_volume_contour
@@ -862,7 +815,7 @@
    ac_coords_3d = np.hstack((FSAVERAGE_MIDDLE, ac_coords))
    pc_coords_3d = np.hstack((FSAVERAGE_MIDDLE, pc_coords))
    standardized_to_orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig = (
-        get_mapping_to_standard_space(orig, ac_coords_3d, pc_coords_3d, orig_fsaverage_vox2vox)
+        calc_mapping_to_standard_space(orig, ac_coords_3d, pc_coords_3d, orig2fsavg_vox2vox)
    )

    # write output dict as json
@@ -883,52 +836,79 @@
    # Convert numpy arrays to lists for JSON serialization
    output_metrics_middle_slice = convert_numpy_to_json_serializable(output_metrics_middle_slice | additional_metrics)

-    logger.info(f"Saving CC markers to {cc_markers_path}")
-    with open(cc_markers_path, "w") as f:
+    logger.info(f"Saving CC markers to {sd.filename_by_attribute('cc_markers')}")
+    with open(sd.filename_by_attribute("cc_markers"), "w") as f:
        json.dump(output_metrics_middle_slice, f, indent=4)
-
    per_slice_output_dict = convert_numpy_to_json_serializable(per_slice_output_dict | additional_metrics)

    # Save slice-wise postprocessing results to JSON
-    with open(postproc_results_path, "w") as f:
+    with open(sd.filename_by_attribute("cc_postproc_results"), "w") as f:
        json.dump(per_slice_output_dict, f, indent=4)
-    if verbose:
-        logger.info(f"Multiple slice post-processing results saved to {postproc_results_path}")
+    logger.info(f"Multiple slice post-processing results saved to {sd.filename_by_attribute('cc_postproc_results')}")

    # save lta to fsaverage space
-    logger.info(f"Saving LTA to fsaverage space: {upright_lta_path}")
-    lta.writeLTA(upright_lta_path, orig_fsaverage_ras2ras, aseg_path, aseg_nib.header, "fsaverage", fsaverage_header)
+    logger.info(f"Saving LTA to fsaverage space: {sd.filename_by_attribute('upright_lta')}")
+    lta.writeLTA(
+        sd.filename_by_attribute("upright_lta"),
+        orig2fsavg_ras2ras,
+        sd.filename_by_attribute("aseg_name"),
+        aseg_nib.header,
+        "fsaverage",
+        fsavg_header,
+    )

    # save lta to standardized space (fsaverage + nodding + ac to center)
-    orig_to_standardized_ras2ras = (
-        orig.affine @ np.linalg.inv(standardized_to_orig_vox2vox) @ np.linalg.inv(orig.affine)
-    )
-    logger.info(f"Saving LTA to standardized space: {orient_volume_lta_path}")
+    orig2standardized_ras2ras = orig.affine @ np.linalg.inv(standardized_to_orig_vox2vox) @ np.linalg.inv(orig.affine)
+    logger.info(f"Saving LTA to standardized space: {sd.filename_by_attribute('cc_orient_volume_lta')}")
    lta.writeLTA(
-        orient_volume_lta_path, orig_to_standardized_ras2ras, t1_path, orig.header, t1_path, orig.header
+        sd.filename_by_attribute("cc_orient_volume_lta"),
+        orig2standardized_ras2ras,
+        sd.conf_name,
+        orig.header,
+        sd.conf_name,
+        orig.header,
    )

-    for process in IO_processes:
-        if process is not None:
-            process.join()
-
+    for e in filter(lambda x: x and isinstance(x, Exception), (fut.exception() for fut in io_futures)):
+        logger.exception(e)

    logger.info("CorpusCallosum analysis pipeline
completed successfully") if __name__ == "__main__": options = options_parse() - main_args = vars(options) - - # Remove parser_defaults arguments that are not needed by main() - main_args.pop("sd", None) - main_args.pop("out_dir", None) - main_args.pop("sid", None) - # Rename keys to match main function parameters - main_args["t1_path"] = main_args.pop("t1") - main_args["aseg_path"] = main_args.pop("aseg_name") - main_args["output_dir"] = main_args.pop("subject_dir", ".") - - main(**main_args) + # Set up logging if verbose mode is enabled + logging.setup_logging(None, options.verbose) # Log to stdout only + + main( + conf_name=options.conf_name, + aseg_name=options.aseg_name, + subject_dir=options.subject_dir, + slice_selection=options.slice_selection, + qc_output_dir=options.qc_output_dir, + num_thickness_points=options.num_thickness_points, + subdivisions=options.subdivisions, + subdivision_method=options.subdivision_method, + contour_smoothing=options.contour_smoothing, + save_template_dir=options.save_template_dir, + device=options.device, + upright_volume_path=options.upright_volume_path, + segmentation_path=options.segmentation_path, + postproc_results_path=options.postproc_results_path, + cc_markers_path=options.cc_markers_path, + upright_lta_path=options.upright_lta_path, + orient_volume_lta_path=options.orient_volume_lta_path, + surf_file_path=options.surf_file_path, + overlay_file_path=options.overlay_file_path, + cc_html_path=options.cc_html_path, + vtk_file_path=options.vtk_file_path, + orig_space_segmentation_path=options.orig_space_segmentation_path, + qc_image_path=options.qc_image_path, + thickness_image_path=options.thickness_image_path, + softlabels_cc_path=options.softlabels_cc_path, + softlabels_fn_path=options.softlabels_fn_path, + softlabels_background_path=options.softlabels_background_path, + ) diff --git a/CorpusCallosum/localization/localization_inference.py b/CorpusCallosum/localization/localization_inference.py index 4b8cf025..d4c0ae29 100644 --- a/CorpusCallosum/localization/localization_inference.py +++ b/CorpusCallosum/localization/localization_inference.py @@ -18,12 +18,16 @@ import torch from monai import transforms from monai.networks.nets import DenseNet +from numpy import typing as npt -from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT from CorpusCallosum.transforms.localization_transforms import CropAroundACPCFixedSize from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults from FastSurferCNN.download_checkpoints import main as download_checkpoints +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT + +PATCH_SIZE = (64, 64) + def load_model(device: torch.device) -> DenseNet: """Load trained numerical localization model from checkpoint. @@ -86,18 +90,14 @@ def get_transforms() -> transforms.Compose: """ tr = [ transforms.ScaleIntensityd(keys=['image'], minv=0, maxv=1), - CropAroundACPCFixedSize( - keys=['image'], - fixed_size=(64, 64), - random_translate=0, - ), + CropAroundACPCFixedSize(keys=['image'], fixed_size=PATCH_SIZE, random_translate=0), ] return transforms.Compose(tr) def preprocess_volume( image_volume: np.ndarray, - center_pt: np.ndarray, + center_pt: npt.NDArray[float], transform: transforms.Transform | None = None ) -> dict[str, torch.Tensor]: """Preprocess a volume for inference. 
@@ -105,9 +105,9 @@ def preprocess_volume( Parameters ---------- image_volume : np.ndarray - Input image volume + Input image volume of shape (W, W, D) in RAS. center_pt : np.ndarray - Center point coordinates for cropping + Center point coordinates for cropping on the slice with shape (3,). transform : transforms.Transform or None, optional Custom transform pipeline, by default None. If None, uses default transforms from get_transforms(). @@ -120,24 +120,25 @@ def preprocess_volume( if transform is None: transform = get_transforms() - sample = {"image": image_volume, "AC_center": center_pt, "PC_center": center_pt} + sample = {"image": image_volume[None], "AC_center": center_pt[1:][None], "PC_center": center_pt[1:][None]} # Apply transforms transformed = transform(sample) # Add batch dimension if needed if torch.is_tensor(transformed["image"]): - if len(transformed["image"].shape) == 3: + if transformed["image"].ndim == 3: transformed["image"] = transformed["image"].unsqueeze(0) return transformed -def run_inference(model: torch.nn.Module, - image_volume: np.ndarray, - third_ventricle_center: np.ndarray, - device: torch.device | None = None, - transform: transforms.Transform | None = None - ) -> tuple[np.ndarray, np.ndarray, np.ndarray, tuple[int, int]]: +def run_inference( + model: torch.nn.Module, + image_volume: np.ndarray, + third_ventricle_center: np.ndarray, + device: torch.device | None = None, + transform: transforms.Transform | None = None + ) -> tuple[npt.NDArray[float], npt.NDArray[float], np.ndarray, tuple[int, int]]: """ Run inference on an image volume @@ -169,43 +170,34 @@ def run_inference(model: torch.nn.Module, device = next(model.parameters()).device #device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # prepend zero to third_ventricle_center third_ventricle_center = np.concatenate([np.zeros(1), third_ventricle_center]) # Preprocess - t_dict = preprocess_volume(image_volume[None], third_ventricle_center, transform) - + t_dict = preprocess_volume(image_volume, third_ventricle_center, transform) transformed_original = t_dict['image'] inputs = transformed_original.to(device) - inputs = inputs.transpose(0, 1) batch_size, channels, height, width = inputs.shape inputs = inputs.unfold(0, 3, 1).swapdims(0, 1).reshape(-1, 3*channels, height, width) - # Run inference with torch.no_grad(): - outputs = model(inputs) - - # Scale outputs to image size - # img_size = torch.tensor([inputs.shape[2], inputs.shape[3], - # inputs.shape[2], inputs.shape[3]], - # dtype=torch.float32, - # device=device) - outputs = outputs * 64 - - t_crops = [[t_dict['crop_left'], t_dict['crop_top'], t_dict['crop_left'], t_dict['crop_top']]] - outs: np.ndarray = (outputs + torch.tensor(t_crops, dtype=outputs.dtype, device=outputs.device)).numpy() - return outs[:, :2], outs[:, 2:], inputs.numpy(), tuple(int(t_dict[k].item()) for k in ['crop_left', 'crop_top']) - - -def run_inference_on_slice(model: DenseNet, - image_slice: np.ndarray, - center_pt: np.ndarray, - debug_output: str | None = None) -> tuple[np.ndarray, np.ndarray]: + outputs = model(inputs) * torch.as_tensor([PATCH_SIZE + PATCH_SIZE], device=device) + + t_crops = [(t_dict['crop_left'] + t_dict['crop_top']) * 2] + outs: npt.NDArray[float] = outputs.cpu().numpy() + np.asarray(t_crops, dtype=float) + return outs[:, :2], outs[:, 2:], inputs.cpu().numpy(), (t_dict["crop_left"][0], t_dict["crop_top"][0]) + + +def run_inference_on_slice( + model: DenseNet, + image_slice: np.ndarray, + center_pt: np.ndarray, + debug_output: str | None = 
None, +) -> tuple[npt.NDArray[float], npt.NDArray[float]]: """Run inference on a single slice to detect AC and PC points. Parameters @@ -213,7 +205,7 @@ def run_inference_on_slice(model: DenseNet, model : torch.nn.Module Trained model for AC-PC detection image_slice : np.ndarray - 3D image slice to run inference on + 3D image mid-slices to run inference on in RAS. center_pt : np.ndarray Initial center point estimate for cropping debug_output : str, optional @@ -228,7 +220,7 @@ def run_inference_on_slice(model: DenseNet, """ # Run inference - pc_coords, ac_coords, _, (crop_left, crop_top) = run_inference(model, image_slice, center_pt) + pc_coords, ac_coords, *_ = run_inference(model, image_slice, center_pt) center_pt = np.mean(np.concatenate([ac_coords, pc_coords], axis=0), axis=0) pc_coords, ac_coords, _, (crop_left, crop_top) = run_inference(model, image_slice, center_pt) pc_coords = np.mean(pc_coords, axis=0) diff --git a/CorpusCallosum/registration/mapping_helpers.py b/CorpusCallosum/registration/mapping_helpers.py index 99958ea0..a9a1a7cc 100644 --- a/CorpusCallosum/registration/mapping_helpers.py +++ b/CorpusCallosum/registration/mapping_helpers.py @@ -120,7 +120,7 @@ def apply_transform_to_pt(pts: npt.NDArray[float], T: npt.NDArray[float], inv: b return (T @ np.concatenate([pts, np.ones((1, pts.shape[1]))]))[:3] -def get_mapping_to_standard_space( +def calc_mapping_to_standard_space( orig: "nib.Nifti1Image", ac_coords_3d: npt.NDArray[float], pc_coords_3d: npt.NDArray[float], @@ -205,26 +205,26 @@ def get_mapping_to_standard_space( def apply_transform_to_volume( - volume: np.ndarray, + orig_image: nib.analyze.SpatialImage, transform: npt.NDArray[float], affine: npt.NDArray[float], - header: nib.freesurfer.mghformat.MGHHeader, + header: nib.freesurfer.mghformat.MGHHeader | None = None, output_path: str | Path | None = None, output_size: np.ndarray | None = None, order: int = 1 -) -> np.ndarray: +) -> npt.NDArray[float]: """Apply transformation to a volume and save the result. Parameters ---------- - volume : np.ndarray - Input volume data. + orig_image : nibabel.analyze.SpatialImage + Input volume. transform : np.ndarray Transformation matrix to apply. affine : np.ndarray Affine matrix for the output image. - header : nib.freesurfer.mghformat.MGHHeader - Header for the output image. + header : nib.freesurfer.mghformat.MGHHeader, optional + Header for the output image, if None will default to orig_image header. output_path : str or Path, optional Path to save transformed volume. output_size : np.ndarray, optional @@ -234,7 +234,7 @@ def apply_transform_to_volume( Returns ------- - np.ndarray + npt.NDArray[float] Transformed volume data. Notes @@ -242,18 +242,19 @@ def apply_transform_to_volume( Uses scipy.ndimage.affine_transform for the transformation. If output_path is provided, saves the result as a MGH file. 
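For context on the `apply_transform_to_volume` hunk above: `scipy.ndimage.affine_transform` pulls values by mapping *output* voxel coordinates back to *input* coordinates, which is why the code inverts the forward vox2vox matrix before resampling. A small self-contained sketch with a synthetic volume and a hypothetical two-voxel shift:

```python
import numpy as np
from scipy.ndimage import affine_transform

vol = np.random.rand(32, 32, 32).astype(np.float32)
fwd = np.eye(4)
fwd[:3, 3] = (2.0, 0.0, 0.0)          # forward transform: shift 2 voxels in x

resampled = affine_transform(
    vol,
    np.linalg.inv(fwd),               # 4x4 homogeneous matrices are accepted
    output_shape=vol.shape,
    order=1,                          # trilinear interpolation
)
# output voxel x samples input voxel x - 2
assert np.allclose(resampled[2:], vol[:-2], atol=1e-5)
```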
""" - if output_size is None: - output_size = np.array(volume.shape) + output_size = np.array(orig_image.shape) + if header is None: + header = orig_image.header transformed = affine_transform( - volume.astype(np.float32), + orig_image.get_data(), np.linalg.inv(transform), output_shape=output_size, order=order, ) if output_path is not None: logger.info(f"Saving transformed volume to {output_path}") - nib.save(nib.MGHImage(transformed, affine, header), output_path) + nib.save(nib.MGHImage(transformed.astype(orig_image.get_data_dtype()), affine, header), output_path) return transformed @@ -293,12 +294,12 @@ def make_affine(simpleITKImage: 'sitk.Image') -> npt.NDArray[float]: def map_softlabels_to_orig( outputs_soft: npt.NDArray[float], orig_fsaverage_vox2vox: npt.NDArray[float], - orig: np.ndarray, + orig: nib.analyze.SpatialImage, slices_to_analyze: int, orig_space_segmentation_path: str | Path | None = None, fsaverage_middle: int = 128, - subdivision_mask: np.ndarray | None = None -) -> np.ndarray: + subdivision_mask: npt.NDArray[int] | None = None +) -> npt.NDArray[int]: """Map soft labels back to original image space and apply post-processing. Parameters @@ -307,7 +308,7 @@ def map_softlabels_to_orig( Soft label predictions. orig_fsaverage_vox2vox : np.ndarray Original to fsaverage space transformation. - orig : np.ndarray + orig : nibabel.analyze.SpatialImage Original image. slices_to_analyze : int Number of slices to analyze. @@ -315,12 +316,12 @@ def map_softlabels_to_orig( Path to save segmentation in original space. fsaverage_middle : int, default=128 Middle slice index in fsaverage space. - subdivision_mask : np.ndarray, optional + subdivision_mask : npt.NDArray[int], optional Mask for subdividing regions. Returns ------- - np.ndarray + npt.NDArray[int] Final segmentation in original image space. Notes diff --git a/CorpusCallosum/segmentation/segmentation_inference.py b/CorpusCallosum/segmentation/segmentation_inference.py index 8878df53..6aed9ce4 100644 --- a/CorpusCallosum/segmentation/segmentation_inference.py +++ b/CorpusCallosum/segmentation/segmentation_inference.py @@ -17,6 +17,7 @@ import numpy as np import torch from monai import transforms +from numpy import typing as npt from CorpusCallosum.data import constants from CorpusCallosum.transforms.segmentation_transforms import CropAroundACPC @@ -24,6 +25,7 @@ from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults from FastSurferCNN.download_checkpoints import main as download_checkpoints from FastSurferCNN.models.networks import FastSurferVINN +from FastSurferCNN.utils.common import thread_executor as executor def load_model(device: torch.device | None = None) -> FastSurferVINN: @@ -31,9 +33,6 @@ def load_model(device: torch.device | None = None) -> FastSurferVINN: Parameters ---------- - checkpoint_path : str or None, optional - Path to model checkpoint, by default None. - If None, downloads and uses default checkpoint. device : torch.device or None, optional Device to load model to, by default None. If None, uses CUDA if available, else CPU. 
@@ -68,10 +67,10 @@ def load_model(device: torch.device | None = None) -> FastSurferVINN: model = FastSurferVINN(params) download_checkpoints(cc=True) - cc_config = load_checkpoint_config_defaults( - "checkpoint", - filename=CC_YAML, - ) + cc_config: dict[str, Path] = load_checkpoint_config_defaults( + "checkpoint", + filename=CC_YAML, + ) checkpoint_path = constants.FASTSURFER_ROOT / cc_config['segmentation'] weights = torch.load(checkpoint_path, weights_only=True, map_location=device) @@ -84,12 +83,12 @@ def load_model(device: torch.device | None = None) -> FastSurferVINN: def run_inference( model: FastSurferVINN, image_slice: np.ndarray, - AC_center: np.ndarray, - PC_center: np.ndarray, + ac_center: np.ndarray, + pc_center: np.ndarray, voxel_size: float, device: torch.device | None = None, transform: transforms.Transform | None = None -) -> dict[str, np.ndarray]: +) -> tuple[npt.NDArray[int], npt.NDArray[float], npt.NDArray[float]]: """Run inference on a single image slice. Parameters @@ -97,10 +96,10 @@ def run_inference( model : FastSurferVINN Trained model image_slice : np.ndarray - Input image as numpy array - AC_center : np.ndarray + LIA-oriented input image as numpy array of shape (L, I, A). + ac_center : np.ndarray Anterior commissure coordinates - PC_center : np.ndarray + pc_center : np.ndarray Posterior commissure coordinates voxel_size : float Voxel size in mm @@ -112,86 +111,50 @@ def run_inference( Returns ------- - dict[str, np.ndarray] - Dictionary containing: - - segmentation : Binary segmentation map - - landmarks : Predicted landmark coordinates + seg_labels : npt.NDArray[int] + The segmentation result. + inputs : npt.NDArray[float] + The inputs to the model. + soft_labels : npt.NDArray[float] + The softlabel output. """ if device is None: - #device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = next(model.parameters()).device - def crop_around_acpc(img: np.ndarray, - ac: np.ndarray, - pc: np.ndarray, - vox_size: float) -> dict[str, np.ndarray]: - """Crop image around AC-PC points. 
- - Parameters - ---------- - img : np.ndarray - Input image - ac : np.ndarray - Anterior commissure coordinates - pc : np.ndarray - Posterior commissure coordinates - vox_size : float - Voxel size in mm - - Returns - ------- - dict[str, np.ndarray] - Dictionary containing cropped image and metadata - """ - return CropAroundACPC(keys=['image'], padding_mm=35, random_translate=0)( - {'image': img, 'AC_center': ac, 'PC_center': pc, 'res': vox_size} - ) + crop_around_acpc = CropAroundACPC(keys=['image'], padding_mm=35, random_translate=0) + to_discrete = transforms.AsDiscrete(argmax=True, to_onehot=3) # Preprocess slice - inputs = torch.from_numpy(image_slice[:,None,:256,:256]) # artifact from training script - crop_dict = crop_around_acpc(inputs, AC_center, PC_center, voxel_size) - inputs, to_pad = crop_dict['image'], crop_dict['to_pad'] - inputs = transforms.utils.rescale_array(inputs, 0, 1, dtype=np.float32) - inputs = inputs.to(device) - - post_trans = transforms.Compose( - [transforms.Activations(softmax=True), transforms.AsDiscrete(argmax=True, to_onehot=3)] - ) + _inputs = torch.from_numpy(image_slice[:,None,:256,:256]) # artifact from training script + sample = {'image': _inputs, 'AC_center': ac_center, 'PC_center': pc_center, 'res': voxel_size} + sample_cropped = crop_around_acpc(sample) + _inputs, to_pad = sample_cropped['image'], sample_cropped['to_pad'] + _inputs = transforms.utils.rescale_array(_inputs, 0, 1, dtype=np.float32).to(device) # split into slices with 9 channels each # Generate views with sliding window of 9 slices - batch_size, channels, height, width = inputs.shape - inputs = inputs.unfold(0, 9, 1).swapdims(0, 1).reshape(-1, 9*channels, height, width) + batch_size, channels, height, width = _inputs.shape + _inputs = _inputs.unfold(0, 9, 1).swapdims(-1, 1).reshape(-1, 9*channels, height, width) # Post-process outputs with torch.no_grad(): - scale_factors = torch.ones((inputs.shape[0], 2), device=device) / voxel_size + scale_factors = torch.ones((_inputs.shape[0], 2), device=device) / voxel_size - outputs = model(inputs, scale_factor=scale_factors) - - # average the outputs along the batch dimension - outputs_avg = torch.mean(outputs, dim=0, keepdim=True) - - outputs_soft = outputs.cpu().numpy() #transforms.Activations(softmax=True)(outputs) # non_discrete outputs - outputs = torch.stack([post_trans(i) for i in outputs]) - outputs_avg = torch.stack([post_trans(i) for i in outputs_avg]) + _logits = model(_inputs, scale_factor=scale_factors) + _softlabels = transforms.Activations(softmax=True, dim=1)(_logits) + softlabels = _softlabels.cpu().numpy() + _labels = torch.stack([to_discrete(i) for i in _softlabels]) + # Pad back to original size, to_pad is a tuple[int, int, int, int] pad_tuples = ((0, 0),) * 2 + (to_pad[:2], to_pad[2:]) - outputs = np.pad(outputs, pad_tuples, mode='constant', constant_values=0) - outputs_avg = np.pad(outputs_avg, pad_tuples, mode='constant', constant_values=0) - outputs_soft = np.pad(outputs_soft, pad_tuples, mode='constant', constant_values=0) + labels = np.pad(_labels.cpu().numpy(), pad_tuples, mode='constant', constant_values=0) + softlabels = np.pad(softlabels, pad_tuples, mode='constant', constant_values=0) - return ( - outputs.transpose(0,2,3,1), - inputs.cpu().numpy().transpose(0,2,3,1), - outputs_avg.transpose(0,2,3,1), - outputs_soft.transpose(0,2,3,1), - ) + return [x.transpose(0, 2, 3, 1) for x in (labels, _inputs.cpu().numpy(), softlabels)] def load_validation_data(path): - from concurrent.futures import ThreadPoolExecutor import 
pandas as pd data = pd.read_csv(path, index_col=0, header=None) @@ -216,7 +179,7 @@ def _load(label_path: str | Path) -> int: return last_nonzero - first_nonzero else: return label_img.shape[0] - label_widths = ThreadPoolExecutor().map(_load, data["label"]) + label_widths = executor().map(_load, data["label"]) return images, ac_centers, pc_centers, label_widths, labels, subj_ids @@ -231,25 +194,27 @@ def one_hot_to_label(one_hot, label_ids=None): -def run_inference_on_slice(model: FastSurferVINN, - test_slice: np.ndarray, - AC_center: np.ndarray, - PC_center: np.ndarray, - voxel_size: float) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: +def run_inference_on_slice( + model: FastSurferVINN, + test_slice: np.ndarray, + ac_center: npt.NDArray[float], + pc_center: npt.NDArray[float], + voxel_size: float, +) -> tuple[npt.NDArray[int], np.ndarray, npt.NDArray[float]]: """Run inference on a single slice. Parameters ---------- model : FastSurferVINN - Trained model for inference + Trained model for inference. test_slice : np.ndarray - Input image slice - AC_center : np.ndarray - Anterior commissure coordinates - PC_center : np.ndarray - Posterior commissure coordinates + Input image slice. + ac_center : npt.NDArray[float] + Anterior commissure coordinates (Inferior and Anterior values). + pc_center : npt.NDArray[float] + Posterior commissure coordinates (Inferior and Posterior values). voxel_size : float - Voxel size in mm + Voxel size in mm. Returns ------- @@ -257,17 +222,15 @@ def run_inference_on_slice(model: FastSurferVINN, Label map after one-hot conversion inputs: np.ndarray Preprocessed input image - outputs_avg: np.ndarray - Averaged model outputs - outputs_soft: np.ndarray + outputs_soft: npt.NDArray[float] Softlabel outputs (non-discrete) """ # add zero in front of AC_center and PC_center - AC_center = np.concatenate([np.zeros(1), AC_center]) - PC_center = np.concatenate([np.zeros(1), PC_center]) + ac_center = np.concatenate([np.zeros(1), ac_center]) + pc_center = np.concatenate([np.zeros(1), pc_center]) - results, inputs, outputs_avg, outputs_soft = run_inference(model, test_slice, AC_center, PC_center, voxel_size) + results, inputs, outputs_soft = run_inference(model, test_slice, ac_center, pc_center, voxel_size) results = one_hot_to_label(results) - return results, inputs, outputs_avg, outputs_soft + return results, inputs, outputs_soft diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py index f084eaaa..90139496 100644 --- a/CorpusCallosum/segmentation/segmentation_postprocessing.py +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -19,7 +19,7 @@ from skimage.measure import label import FastSurferCNN.utils.logging as logging -from CorpusCallosum.data.constants import CC_LABEL, FORNIX_LABEL +from CorpusCallosum.data.constants import CC_LABEL logger = logging.get_logger(__name__) @@ -204,8 +204,6 @@ def connect_nearby_components(seg_arr: np.ndarray, max_connection_distance: floa # Use the largest component as the reference main_component_id = component_sizes[0][0] - - logger.info(f"Found {len(component_ids)} disconnected components. 
" f"Attempting to connect smaller components to main component (size: {component_sizes[0][1]})") @@ -425,7 +423,7 @@ def get_cc_volume_contour(cc_contours: list[np.ndarray], return integrate.simpson(areas, x=measurement_points) -def get_largest_cc( +def extract_largest_connected_component( seg_arr: np.ndarray, max_connection_distance: float = 3.0 ) -> np.ndarray: @@ -486,24 +484,24 @@ def get_largest_cc( return largest_cc def clean_cc_segmentation( - seg_arr: np.ndarray, + seg_arr: npt.NDArray[int], max_connection_distance: float = 3.0 ) -> tuple[np.ndarray, np.ndarray]: """Clean corpus callosum segmentation by removing non-connected components. Parameters ---------- - seg_arr : np.ndarray + seg_arr : npt.NDArray[int] Input segmentation array with CC (192) and fornix (250) labels - max_connection_distance : float, optional - Maximum distance to connect components, by default 3.0 + max_connection_distance : float, default=3.0 + Maximum distance to connect components. Returns ------- - tuple[np.ndarray, np.ndarray] - - clean_seg : Cleaned segmentation array with only the largest - connected component of CC and fornix - - mask : Binary mask of the largest connected component + clean_seg : np.NDArray[int] + Cleaned segmentation array with only the largest connected component of CC and fornix. + mask : npt.NDArray[bool] + Binary mask of the largest connected component. Notes ----- @@ -513,24 +511,16 @@ def clean_cc_segmentation( 3. Adds the fornix (label 250) 4. Removes non-connected components from the combined CC and fornix """ - # Remove non connected components from the CC alone, with minimal connections - cc_seg = np.zeros_like(seg_arr) - cc_seg[seg_arr == CC_LABEL] = CC_LABEL + from functools import partial - cc_label_cleaned = np.zeros_like(cc_seg) - for i in range(cc_seg.shape[0]): - cc_label_cleaned[i] = get_largest_cc(cc_seg[None,i], max_connection_distance) - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots(1,3) - # ax[0].imshow(cc_seg[i]) - # ax[1].imshow(mask[i]) - # ax[2].imshow(cc_seg[i] - mask[i]*CC_LABEL) # difference between pre and post clean - # plt.show() + extract_largest = partial(extract_largest_connected_component, max_connection_distance=max_connection_distance) + # Remove non connected components from the CC alone, with minimal connections + mask = seg_arr == CC_LABEL + cc_seg = mask.astype(int) * CC_LABEL + cc_label_cleaned = np.concatenate([extract_largest(seg[None]) * CC_LABEL for seg in cc_seg], axis=0) # Add fornix to the CC labels - clean_seg = np.zeros_like(seg_arr) - clean_seg[cc_label_cleaned > 0] = CC_LABEL - clean_seg[seg_arr == FORNIX_LABEL] = FORNIX_LABEL + clean_seg = np.where(mask, cc_label_cleaned, seg_arr) return clean_seg, cc_label_cleaned > 0 diff --git a/CorpusCallosum/shape/cc_endpoint_heuristic.py b/CorpusCallosum/shape/cc_endpoint_heuristic.py index b9d18818..d88f0e6a 100644 --- a/CorpusCallosum/shape/cc_endpoint_heuristic.py +++ b/CorpusCallosum/shape/cc_endpoint_heuristic.py @@ -228,13 +228,11 @@ def get_endpoints( rotated_PC_2d = (rot_matrix @ pc_centered) + origin_point rotated_AC_2d = (rot_matrix @ ac_centered) + origin_point - # Add z=0 coordinate to make 3D, then remove it after resampling contour_3d = np.vstack([contour, np.zeros(contour.shape[1])]) contour_3d = lapy.tria_mesh.TriaMesh._TriaMesh__resample_polygon(contour_3d.T, 701).T contour = contour_3d[:2] - contour = contour[:, :-1] diff --git a/CorpusCallosum/shape/cc_mesh.py b/CorpusCallosum/shape/cc_mesh.py index dc6b7dad..ed41150e 100644 --- 
a/CorpusCallosum/shape/cc_mesh.py +++ b/CorpusCallosum/shape/cc_mesh.py @@ -21,7 +21,6 @@ import nibabel as nib import numpy as np import plotly.graph_objects as go -import pyrr import scipy.interpolate from scipy.ndimage import gaussian_filter1d @@ -29,10 +28,18 @@ from CorpusCallosum.shape.cc_endpoint_heuristic import smooth_contour from CorpusCallosum.shape.cc_thickness import HiddenPrints, make_mesh_from_contour +try: + from pyrr import Matrix44 + HAS_PYRR = True +except ImportError: + HAS_PYRR = False + class Matrix44(np.ndarray): + pass + logger = logging.get_logger(__name__) -class CC_Mesh(lapy.TriaMesh): +class CCMesh(lapy.TriaMesh): """A class for representing and manipulating corpus callosum (CC) meshes. This class extends lapy.TriaMesh to provide specialized functionality for working with @@ -73,6 +80,7 @@ def __init__(self, num_slices): num_slices : int Number of slices in the corpus callosum mesh """ + super().__init__(np.zeros((3, 3)), np.zeros((3, 3), dtype=int)) self.contours = [None] * num_slices self.thickness_values = [None] * num_slices self.start_end_idx = [None] * num_slices @@ -139,7 +147,7 @@ def set_resolution(self, resolution: float): def plot_mesh( self, - output_path: str | None = None, + output_path: Path | str | None = None, colormap: str = "red_to_yellow", thickness_overlay: bool = True, show_contours: bool = False, @@ -156,7 +164,7 @@ def plot_mesh( Parameters ---------- - output_path : str, optional + output_path : Path, str, optional Path to save the plot. If None, displays the plot interactively. colormap : str, optional Which colormap to use, by default "red_to_yellow". @@ -439,13 +447,12 @@ def plot_mesh( fig.write_html(output_path) # Save as interactive HTML else: # For non-interactive display, save to a temporary HTML and open in browser - import os import tempfile import webbrowser - temp_path = os.path.join(tempfile.gettempdir(), "cc_mesh_plot.html") + temp_path = Path(tempfile.gettempdir()) / "cc_mesh_plot.html" fig.write_html(temp_path) - webbrowser.open("file://" + temp_path) + webbrowser.open(f"file://{temp_path}") def get_contour_edge_lengths(self, contour_idx: int) -> np.ndarray: """Get the lengths of the edges of a contour. @@ -1200,13 +1207,13 @@ def set_mesh(self, self.mesh_vertex_colors = np.array([]) @staticmethod - def __create_cc_viewmat() -> pyrr.Matrix44: + def __create_cc_viewmat() -> "Matrix44": """Create the view matrix for a nice view of the corpus callosum. Returns ------- - pyrr.Matrix44 - 4x4 view matrix that provides a standard view of the corpus callosum + Matrix44 + 4x4 view matrix that provides a standard view of the corpus callosum (from pyrr). Notes ----- @@ -1218,42 +1225,46 @@ def __create_cc_viewmat() -> pyrr.Matrix44: - -8 degrees around z-axis 3. 
Adds a small translation for better centering """ + + if not HAS_PYRR: + raise ImportError("Pyrr not installed, install pyrr with `pip install pyrr`.") + viewLeft = np.array([[0, 0, -1, 0], [-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) # left w top up // right - transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) + transl = Matrix44.from_translation((0, 0, 0.4)) viewmat = transl * viewLeft # rotate 10 degrees around x axis - rot = pyrr.Matrix44.from_x_rotation(np.deg2rad(-10)) + rot = Matrix44.from_x_rotation(np.deg2rad(-10)) viewmat = viewmat * rot # rotate 35 degrees around y axis - rot = pyrr.Matrix44.from_y_rotation(np.deg2rad(35)) + rot = Matrix44.from_y_rotation(np.deg2rad(35)) viewmat = viewmat * rot # rotate 10 degrees around z axis - rot = pyrr.Matrix44.from_z_rotation(np.deg2rad(-8)) + rot = Matrix44.from_z_rotation(np.deg2rad(-8)) viewmat = viewmat * rot return viewmat def snap_cc_picture( self, - output_path: str, - fssurf_file: str | None = None, - overlay_file: str | None = None + output_path: Path | str, + fssurf_file: Path | str | None = None, + overlay_file: Path | str | None = None ) -> None: """Snap a picture of the corpus callosum mesh. Parameters ---------- - output_path : str + output_path : Path, str Path where to save the snapshot image. - fssurf_file : str or None, optional + fssurf_file : Path, str, optional Path to a FreeSurfer surface file to use for the snapshot. - If None, the mesh is saved to a temporary file, by default None. - overlay_file : str or None, optional + If None, the mesh is saved to a temporary file. + overlay_file : Path, str, optional Path to a FreeSurfer overlay file to use for the snapshot. - If None, the mesh is saved to a temporary file, by default None. + If None, the mesh is saved to a temporary file. Raises ------ @@ -1269,7 +1280,6 @@ def snap_cc_picture( - Ambient lighting and colorbar settings. - Thickness overlay if available. 3. Cleans up temporary files after use. 
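Both `pyrr` (guarded above with `HAS_PYRR`) and `whippersnappy` (imported lazily below) are treated as optional dependencies. A generic sketch of that guard pattern; `render_view` is a hypothetical consumer, not part of the module:

```python
try:
    from pyrr import Matrix44
    HAS_PYRR = True
except ImportError:
    HAS_PYRR = False

    class Matrix44:  # stub so type annotations still resolve without pyrr
        pass


def render_view() -> "Matrix44":
    # fail with an actionable message only when the feature is actually used
    if not HAS_PYRR:
        raise ImportError("pyrr is not installed; install it with `pip install pyrr`.")
    return Matrix44.from_x_rotation(0.0)
```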
- """ try: from whippersnappy.core import snap1 @@ -1282,30 +1292,29 @@ def snap_cc_picture( self.__make_parent_folder(output_path) # Skip snapshot if there are no faces if len(self.t) == 0: - print("Warning: Cannot create snapshot - no faces in mesh") + logger.warning("Cannot create snapshot - no faces in mesh") return # create temp file - if fssurf_file is None: - temp_file = tempfile.NamedTemporaryFile(suffix=".fssurf", delete=True) - self.write_fssurf(temp_file.name) + if fssurf_file: + fssurf_file = Path(fssurf_file) else: - temp_file = Path(fssurf_file) - - if overlay_file is None: - if hasattr(self, "mesh_vertex_colors"): - overlay_file = tempfile.NamedTemporaryFile(suffix=".w", delete=True) - # Write thickness values in FreeSurfer .w format - nib.freesurfer.write_morph_data(overlay_file.name, self.mesh_vertex_colors) - overlaypath = overlay_file.name - else: - overlaypath = None + fssurf_file = tempfile.NamedTemporaryFile(suffix=".fssurf", delete=True) + self.write_fssurf(fssurf_file.name) + + if overlay_file: + overlay_path: str | None = Path(overlay_file).name + elif hasattr(self, "mesh_vertex_colors"): + overlay_file = tempfile.NamedTemporaryFile(suffix=".w", delete=True) + # Write thickness values in FreeSurfer .w format + nib.freesurfer.write_morph_data(overlay_file.name, self.mesh_vertex_colors) + overlay_path = overlay_file.name else: - overlaypath = Path(overlay_file).name + overlay_path = None snap1( - temp_file.name, - overlaypath=overlaypath, + fssurf_file.name, + overlaypath=overlay_path, view=None, viewmat=self.__create_cc_viewmat(), width=3 * 500, @@ -1323,8 +1332,10 @@ def snap_cc_picture( caption_scale=0.5, ) - temp_file.close() - overlay_file.close() + if fssurf_file and hasattr(fssurf_file, "close"): + fssurf_file.close() + if overlay_file and hasattr(overlay_file, "close"): + overlay_file.close() def smooth_(self, iterations: int = 1) -> None: """Smooth the mesh while preserving the z-coordinates. @@ -1345,12 +1356,12 @@ def smooth_(self, iterations: int = 1) -> None: super().smooth_(iterations) self.v[:, 2] = z_values - def save_contours(self, output_path: str) -> None: + def save_contours(self, output_path: Path | str) -> None: """Save the contours to a CSV file. Parameters ---------- - output_path : str + output_path : Path, str Path where to save the CSV file. Notes @@ -1430,12 +1441,12 @@ def load_contours(self, input_path: str) -> None: self.contours = self.contours + [None] * (max_slices - len(self.contours)) self.start_end_idx = self.start_end_idx + [None] * (max_slices - len(self.start_end_idx)) - def save_thickness_values(self, output_path: str) -> None: + def save_thickness_values(self, output_path: Path | str) -> None: """Save thickness values to a CSV file. Parameters ---------- - output_path : str + output_path : Path, str Path where to save the CSV file. Notes @@ -1564,12 +1575,12 @@ def load_thickness_values( self.thickness_values = new_thickness_values @staticmethod - def __make_parent_folder(filename: str) -> None: + def __make_parent_folder(filename: Path | str) -> None: """Create the parent folder for a file if it doesn't exist. Parameters ---------- - filename : str + filename : Path, str Path to the file whose parent folder should be created. Notes @@ -1577,13 +1588,12 @@ def __make_parent_folder(filename: str) -> None: Creates parent directory with parents=False to avoid creating multiple levels of directories unintentionally. 
""" - output_folder = Path(filename).parent - output_folder.mkdir(parents=False, exist_ok=True) + Path(filename).parent.mkdir(parents=False, exist_ok=True) def to_fs_coordinates( self, vox2ras_tkr: np.ndarray, - vox_size: tuple[float, float, float] + vox_size: tuple[float, float, float], ) -> None: """Convert mesh coordinates to FreeSurfer coordinate system. @@ -1591,7 +1601,7 @@ def to_fs_coordinates( ---------- vox2ras_tkr : np.ndarray 4x4 voxel to RAS tkr-space transformation matrix. - vox_size : tuple[float, float, float] + vox_size : 3-tuple of floats Voxel size in millimeters (x, y, z). Notes @@ -1615,7 +1625,6 @@ def to_fs_coordinates( # flip SI v_vox[:, 1] = -v_vox[:, 1] - #v_vox_test = np.round(v_vox).astype(int) ## write volume for debugging # contour_img = np.zeros(orig.shape) @@ -1627,25 +1636,15 @@ def to_fs_coordinates( # https://surfer.nmr.mgh.harvard.edu/fswiki/CoordinateSystems self.v = (vox2ras_tkr @ np.concatenate([v_vox, np.ones((self.v.shape[0], 1))], axis=1).T).T[:, :3] self.v = self.v * vox_size[0] - - - - - - def write_fssurf(self, filename: str) -> None: + def write_fssurf(self, filename: Path | str) -> None: """Write the mesh to a FreeSurfer surface file. Parameters ---------- - filename : str + filename : Path, str Path where to save the FreeSurfer surface file. - Returns - ------- - None - Returns the result of the parent class's write_fssurf method. - Notes ----- Creates parent directory if needed before writing the file. @@ -1653,19 +1652,14 @@ def write_fssurf(self, filename: str) -> None: self.__make_parent_folder(filename) return super().write_fssurf(filename) - def write_overlay(self, filename: str) -> None: + def write_overlay(self, filename: Path | str) -> None: """Write the thickness values as a FreeSurfer overlay file. Parameters ---------- - filename : str + filename : Path, str Path where to save the overlay file. - Returns - ------- - None - Returns the result of writing the morph data using nibabel. - Notes ----- Creates parent directory if needed before writing the file. @@ -1673,12 +1667,12 @@ def write_overlay(self, filename: str) -> None: self.__make_parent_folder(filename) return nib.freesurfer.write_morph_data(filename, self.mesh_vertex_colors) - def save_thickness_measurement_points(self, filename: str) -> None: + def save_thickness_measurement_points(self, filename: Path | str) -> None: """Write the thickness measurement points to a CSV file. Parameters ---------- - filename : str + filename : Path, str Path where to save the CSV file. Notes diff --git a/CorpusCallosum/shape/cc_metrics.py b/CorpusCallosum/shape/cc_metrics.py index 7bd62886..491191d7 100644 --- a/CorpusCallosum/shape/cc_metrics.py +++ b/CorpusCallosum/shape/cc_metrics.py @@ -25,8 +25,8 @@ def calculate_cc_index(cc_contour: np.ndarray) -> float: Returns ------- - float - Sum of thicknesses at three measurement points divided by AP length. + cc_index : float + The CC index, which is the sum of thicknesses at three measurement points divided by AP length. 
""" # Get anterior and posterior points anterior_idx = np.argmin(cc_contour[0]) # Leftmost point @@ -64,14 +64,9 @@ def get_intersections(start_point: np.ndarray, direction: np.ndarray) -> np.ndar signs = np.sign(dots) sign_changes = np.where(np.diff(signs))[0] - intersections = [] - for idx in sign_changes: - # Linear interpolation between points - t = -dots[idx] / (dots[idx + 1] - dots[idx]) - intersection = cc_contour[:, idx] + t * (cc_contour[:, idx + 1] - cc_contour[:, idx]) - intersections.append(intersection) - - return np.array(intersections) + # Linear interpolation between points + t = -dots[sign_changes] / (dots[sign_changes + 1] - dots[sign_changes]) + return cc_contour[:, sign_changes] + t * (cc_contour[:, sign_changes + 1] - cc_contour[:, sign_changes]) # Get three measurements most_anterior_pt = cc_contour[:, anterior_idx] @@ -98,7 +93,7 @@ def get_intersections(start_point: np.ndarray, direction: np.ndarray) -> np.ndar posterior_distance = np.linalg.norm(anterior_intersections[-1] - anterior_intersections[-2]) top_distance = np.linalg.norm(middle_ints[0] - middle_ints[1]) - index = (anterior_distance + posterior_distance + top_distance) / ap_distance + cc_index = (anterior_distance + posterior_distance + top_distance) / ap_distance # fig, ax = plt.subplots(figsize=(8, 6)) @@ -150,4 +145,4 @@ def get_intersections(start_point: np.ndarray, direction: np.ndarray) -> np.ndar # plt.axis('off') # plt.show() - return index + return cc_index diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/cc_postprocessing.py index f5ad2054..cdda1b96 100644 --- a/CorpusCallosum/shape/cc_postprocessing.py +++ b/CorpusCallosum/shape/cc_postprocessing.py @@ -11,17 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -import multiprocessing +import concurrent.futures from pathlib import Path +from typing import Literal, get_args import numpy as np +from numpy import typing as npt import FastSurferCNN.utils.logging as logging -from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE, SUBSEGMENT_LABELS -from CorpusCallosum.data.read_write import run_in_background +from CorpusCallosum.data.constants import CC_LABEL, FSAVERAGE_MIDDLE, SUBSEGMENT_LABELS from CorpusCallosum.shape.cc_endpoint_heuristic import get_endpoints -from CorpusCallosum.shape.cc_mesh import CC_Mesh +from CorpusCallosum.shape.cc_mesh import CCMesh from CorpusCallosum.shape.cc_metrics import calculate_cc_index from CorpusCallosum.shape.cc_subsegment_contour import ( get_primary_eigenvector, @@ -33,9 +33,12 @@ from CorpusCallosum.shape.cc_thickness import cc_thickness, convert_to_ras from CorpusCallosum.utils.utils import HiddenPrints from CorpusCallosum.visualization.visualization import plot_contours +from FastSurferCNN.utils.common import SubjectDirectory, update_docstring +from FastSurferCNN.utils.common import thread_executor as executor -logger = logging.get_logger(__name__) +SubdivisionMethod = Literal["shape", "vertical", "angular", "eigenvector"] +logger = logging.get_logger(__name__) # assert LIA orientation LIA_ORIENTATION = np.zeros((3,3)) @@ -44,20 +47,28 @@ LIA_ORIENTATION[2,1] = -1 -def create_visualization(subdivision_method: str, result: dict, midslices_data: np.ndarray, - output_image_path: str | Path, ac_coords: np.ndarray, - pc_coords: np.ndarray, vox_size: float, title_suffix: str = "") -> multiprocessing.Process: +@update_docstring(SubdivisionMethod=str(get_args(SubdivisionMethod))[1:-1]) +def async_create_visualization( + subdivision_method: SubdivisionMethod, + result: dict, + midslices_data: np.ndarray, + output_image_path: str | Path, + ac_coords: np.ndarray, + pc_coords: np.ndarray, + vox_size: float, + title_suffix: str = "", +) -> concurrent.futures.Future: """Create visualization plots based on subdivision method. Parameters ---------- - subdivision_method : str + subdivision_method : {SubdivisionMethod} The subdivision method being used. result : dict Dictionary containing processing results with split_contours. midslices_data : np.ndarray Slice data for visualization. - output_image_path : str or Path + output_image_path : Path, str Path to save visualization. ac_coords : np.ndarray AC coordinates. @@ -73,23 +84,22 @@ def create_visualization(subdivision_method: str, result: dict, midslices_data: multiprocessing.Process Process object for background execution. 
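The `{SubdivisionMethod}` placeholder in the docstring above is expanded by the `update_docstring` decorator combined with `typing.get_args` on the `Literal` alias. The real helper lives in `FastSurferCNN.utils.common` and may differ; the stand-in below is only a plausible minimal version to show the mechanism:

```python
from typing import Literal, get_args

SubdivisionMethod = Literal["shape", "vertical", "angular", "eigenvector"]

def update_docstring(**kwargs):              # hypothetical minimal stand-in
    def decorator(func):
        func.__doc__ = func.__doc__.format(**kwargs)
        return func
    return decorator

@update_docstring(SubdivisionMethod=str(get_args(SubdivisionMethod))[1:-1])
def subdivide(method: SubdivisionMethod) -> None:
    """Subdivide a contour.

    Parameters
    ----------
    method : {SubdivisionMethod}
        One of the allowed strategies.
    """

print(subdivide.__doc__)  # placeholder replaced by the Literal's choices
```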
""" - title = f'CC Subsegmentation by {subdivision_method} {title_suffix}' + title = f"CC Subsegmentation by {subdivision_method} {title_suffix}" args_dict = { - 'debug': True, - 'transformed': midslices_data, - 'split_contours': result['split_contours'], - 'midline_equidistant': result['midline_equidistant'], - 'levelpaths': result['levelpaths'], - 'output_path': output_image_path, - 'ac_coords': ac_coords, - 'pc_coords': pc_coords, - 'vox_size': vox_size, - 'title': title, + "debug": True, + "transformed": midslices_data, + "split_contours": result["split_contours"], + "midline_equidistant": result["midline_equidistant"], + "levelpaths": result["levelpaths"], + "output_path": output_image_path, + "ac_coords": ac_coords, + "pc_coords": pc_coords, + "vox_size": vox_size, + "title": title, } - return run_in_background(plot_contours, **args_dict) - + return executor().submit(plot_contours, **args_dict) def create_slice_affine(temp_seg_affine: np.ndarray, slice_idx: int, fsaverage_middle: int) -> np.ndarray: @@ -114,7 +124,214 @@ def create_slice_affine(temp_seg_affine: np.ndarray, slice_idx: int, fsaverage_m return slice_affine -def process_slice( +@update_docstring(SubdivisionMethod=str(get_args(SubdivisionMethod))[1:-1]) +def recon_cc_surf_measures_multi( + segmentation: np.ndarray, + slice_selection: str, + temp_seg_affine: np.ndarray, + midslices: np.ndarray, + ac_coords: np.ndarray, + pc_coords: np.ndarray, + num_thickness_points: int, + subdivisions: list[float], + subdivision_method: SubdivisionMethod, + contour_smoothing: float, + subject_dir: SubjectDirectory, + qc_image_path: str | None = None, + vox_size: tuple[float, float, float] | None = None, + vox2ras_tkr: np.ndarray | None = None, +) -> tuple[list, list[concurrent.futures.Future]]: + """Surface reconstruction and metrics computation of corpus callosum slices based on selection mode. + + Parameters + ---------- + segmentation : np.ndarray + 3D segmentation array. + slice_selection : str + Which slices to process ('middle', 'all', or slice number). + temp_seg_affine : np.ndarray + Base affine transformation matrix. + midslices : np.ndarray + Array of mid-sagittal slices. + ac_coords : np.ndarray + Anterior commissure coordinates. + pc_coords : np.ndarray + Posterior commissure coordinates. + num_thickness_points : int + Number of points for thickness estimation. + subdivisions : list[float] + List of fractions for anatomical subdivisions. + subdivision_method : {SubdivisionMethod} + Method for contour subdivision. + contour_smoothing : float + Gaussian sigma for contour smoothing. + subject_dir : SubjectDirectory + The SubjectDirectory object managing file names in the subject directory. + qc_image_path : Path, str, optional + Path for QC visualization image. + vox_size : 3-tuple of floats, optional + Voxel size in millimeters (x, y, z). + vox2ras_tkr : np.ndarray, optional + Voxel to RAS tkr-space transformation matrix. + + Returns + ------- + list + List of slice processing results. + list[concurrent.futures.Future] + List of background IO processes. 
+ """ + slice_results = [] + io_futures = [] + + if slice_selection == "middle": + cc_mesh = CCMesh(num_slices=1) + cc_mesh.set_acpc_coords(ac_coords, pc_coords) + cc_mesh.set_resolution(vox_size[0]) + + # Process only the middle slice + slice_idx = segmentation.shape[0] // 2 + slice_affine = create_slice_affine(temp_seg_affine, slice_idx, FSAVERAGE_MIDDLE) + + result, contour_with_thickness, *endpoint_idxs = recon_cc_surf_measure( + segmentation, + slice_idx, + ac_coords, + pc_coords, + slice_affine, + num_thickness_points, + subdivisions, + subdivision_method, + contour_smoothing, + vox_size[0], + ) + + cc_mesh.add_contour(0, *contour_with_thickness, start_end_idx=endpoint_idxs) + + if result is not None and qc_image_path is not None: + slice_results.append(result) + # Create visualization + logger.info(f"Saving segmentation qc image to {qc_image_path}") + io_futures.append(async_create_visualization( + subdivision_method, + result, + midslices, + qc_image_path, + ac_coords, + pc_coords, + vox_size[0], + )) + else: + num_slices = segmentation.shape[0] + cc_mesh = CCMesh(num_slices=num_slices) + cc_mesh.set_acpc_coords(ac_coords, pc_coords) + cc_mesh.set_resolution(vox_size[0]) + + # Process multiple slices or specific slice + if slice_selection == "all": + start_slice = 0 + end_slice = segmentation.shape[0] + else: # specific slice number + slice_idx = int(slice_selection) + start_slice = slice_idx + end_slice = slice_idx + 1 + + for slice_idx in range(start_slice, end_slice): + logger.info(f"Calculating CC measurements for slice {slice_idx+1} of {end_slice-start_slice}") + + # Update affine for this slice + slice_affine = create_slice_affine(temp_seg_affine, slice_idx, FSAVERAGE_MIDDLE) + + # Process this slice + result, contour_with_thickness, *endpoint_idxs = recon_cc_surf_measure( + segmentation, + slice_idx, + ac_coords, + pc_coords, + slice_affine, + num_thickness_points, + subdivisions, + subdivision_method, + contour_smoothing, + vox_size[0], + ) + + # insert + cc_mesh.add_contour(slice_idx, *contour_with_thickness, start_end_idx=endpoint_idxs) + + if result is not None: + slice_results.append(result) + + if logger.getEffectiveLevel() <= logging.INFO and subject_dir.has_attribute("cc_qc_image"): + qc_img = subject_dir.filename_by_attribute("cc_qc_image") + if logger.getEffectiveLevel() <= logging.DEBUG: + qc_img = (qc_img.parent / f"{qc_img.stem}_slice_{slice_idx}{qc_img.suffix}").with_suffix(".png") + + if logger.getEffectiveLevel() <= logging.DEBUG or slice_idx == num_slices // 2: + logger.info(f"Saving segmentation qc image to {qc_img}") + + current_slice_in_volume = midslices.shape[0] // 2 - num_slices // 2 + slice_idx + # Create visualization for this slice + io_futures.append(async_create_visualization( + subdivision_method, + result, + midslices[current_slice_in_volume:current_slice_in_volume+1], + qc_img, + ac_coords, + pc_coords, + vox_size[0], + f" (Slice {slice_idx})", + )) + + if subject_dir.has_attribute("save_template_dir"): + template_dir = subject_dir.filename_by_attribute("save_template_dir") + # ensure directory exists + template_dir.mkdir(parents=True, exist_ok=True) + logger.info("Saving template files (contours.txt, thickness_values.txt, " + f"thickness_measurement_points.txt) to {template_dir}") + cc_mesh.save_contours(template_dir / "contours.txt") + cc_mesh.save_thickness_values(template_dir / "thickness_values.txt") + cc_mesh.save_thickness_measurement_points(template_dir / "thickness_measurement_points.txt") + + + if len(cc_mesh.contours) > 1 and 
subject_dir.has_attribute("cc_html"): + cc_mesh.fill_thickness_values() + cc_mesh.create_mesh() + cc_mesh.smooth_(1) + logger.info(f"Saving CC 3D visualization to {subject_dir.filename_by_attribute('cc_html')}") + cc_mesh.plot_mesh(output_path=subject_dir.filename_by_attribute("cc_html"), show_mesh_edges=True) + + if subject_dir.has_attribute("cc_mesh"): + vtk_file_path = subject_dir.filename_by_attribute("cc_mesh") + logger.info(f"Saving vtk file to {vtk_file_path}") + cc_mesh.write_vtk(vtk_file_path) + + cc_mesh.to_fs_coordinates(vox2ras_tkr=vox2ras_tkr, vox_size=vox_size) + if subject_dir.has_attribute("overlay_file"): + overlay_file_path = subject_dir.filename_by_attribute("overlay_file") + logger.info(f"Saving overlay file to {overlay_file_path}") + cc_mesh.write_overlay(overlay_file_path) + + if subject_dir.has_attribute("cc_surf_file"): + surf_file_path = subject_dir.filename_by_attribute("cc_surf_file") + logger.info(f"Saving surf file to {surf_file_path}") + cc_mesh.write_fssurf(surf_file_path) + + if subject_dir.has_attribute("thickness_image"): + thickness_image_path = subject_dir.filename_by_attribute("thickness_image") + logger.info(f"Saving thickness image to {thickness_image_path}") + with HiddenPrints(): + cc_mesh.snap_cc_picture(thickness_image_path) + + + if not slice_results: + logger.error("Error: No valid slices were found for postprocessing") + raise ValueError("No valid slices were found for postprocessing") + + return slice_results, io_futures + + +def recon_cc_surf_measure( segmentation: np.ndarray, slice_idx: int, ac_coords: np.ndarray, @@ -122,11 +339,11 @@ def process_slice( affine: np.ndarray, num_thickness_points: int, subdivisions: list[float], - subdivision_method: str, + subdivision_method: SubdivisionMethod, contour_smoothing: float, vox_size: float -) -> dict | None: - """Process a single slice for corpus callosum measurements. +) -> tuple[dict[str, float | int | np.ndarray | list[float]], np.ndarray, int, int]: + """Reconstruct surfaces and compute measures for a single slice for the corpus callosum. Parameters ---------- @@ -144,7 +361,7 @@ def process_slice( Number of points for thickness estimation. subdivisions : list[float] List of fractions for anatomical subdivisions. - subdivision_method : str + subdivision_method : SubdivisionMethod Method for contour subdivision ('shape', 'vertical', 'angular', or 'eigenvector'). contour_smoothing : float Gaussian sigma for contour smoothing. @@ -153,7 +370,7 @@ def process_slice( Returns ------- - dict | None + dict of measures Dictionary containing measurements if successful, including: - cc_index : float - Corpus callosum shape index. - circularity : float - Shape circularity measure. @@ -169,7 +386,9 @@ def process_slice( - levelpaths : list[np.ndarray] - Paths for thickness measurements. - thickness_measurement_points : np.ndarray - Points where thickness was measured. - slice_index : int - Index of the processed slice. - Returns None if no CC is found in the slice. + contour_with_thickness : np.ndarray + anterior_endpoint_index : int + posterior_endpoint_index : int Raises ------ @@ -186,51 +405,52 @@ def process_slice( 5. Generates visualization data. 
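Several of the measures listed above (midline length, perimeter, thickness profile) reduce to arc lengths of 2D polylines. A simplified, self-contained sketch of the standard formula, not the exact expression used in `cc_thickness`:

```python
import numpy as np

def polyline_length(points: np.ndarray) -> float:
    """Arc length of an ordered (N, 2) polyline."""
    segments = np.diff(points, axis=0)                  # (N-1, 2) edge vectors
    return float(np.sum(np.linalg.norm(segments, axis=1)))

theta = np.linspace(0.0, np.pi, 200)
half_circle = np.stack([np.cos(theta), np.sin(theta)], axis=1)
print(polyline_length(half_circle))                     # ~3.1414, close to pi
```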
""" - - cc_mask_slice = segmentation[slice_idx] == 192 + cc_mask_slice: npt.NDArray[bool] = segmentation[slice_idx] == CC_LABEL if not np.any(cc_mask_slice): - raise ValueError(f'No CC found in slice {slice_idx}') - - - contour, anterior_endpoint_idx, posterior_endpoint_idx = get_endpoints(cc_mask_slice, ac_coords, pc_coords, - vox_size, - return_coordinates=False, - contour_smoothing=contour_smoothing) + raise ValueError(f"No CC found in slice {slice_idx}") + + contour, *endpoint_idxs = get_endpoints( + cc_mask_slice, + ac_coords, + pc_coords, + vox_size, + return_coordinates=False, + contour_smoothing=contour_smoothing, + ) contour_1mm = convert_to_ras(contour, affine) - (midline_length, thickness, curvature, midline_equidistant, levelpaths, - contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx) = cc_thickness(contour_1mm.T, - anterior_endpoint_idx, - posterior_endpoint_idx, - n_points=num_thickness_points) - + midline_len, thickness, curvature, midline_equi, levelpaths, contour_with_thickness, *endpoint_idxs = cc_thickness( + contour_1mm.T, + *endpoint_idxs, + n_points=num_thickness_points, + ) + thickness_profile = [ np.sum(np.sqrt(np.diff(np.array(levelpath[:,:2]), axis=0)**2), axis=0) for levelpath in levelpaths ] thickness_profile = np.linalg.norm(np.array(thickness_profile),axis=1) - contour_acpc, ac_pt_acpc, pc_pt_acpc, rotate_back_acpc = transform_to_acpc_standard(contour_1mm, - contour_1mm[:,anterior_endpoint_idx], - contour_1mm[:,posterior_endpoint_idx]) + acpc_contour_coords = contour_1mm[:, list(endpoint_idxs)].T + contour_acpc, ac_pt_acpc, pc_pt_acpc, rotate_back_acpc = transform_to_acpc_standard( + contour_1mm, + *acpc_contour_coords, + ) cc_index = calculate_cc_index(contour_acpc) # Apply different subdivision methods based on user choice if subdivision_method == "shape": - areas, split_contours = subsegment_midline_orthogonal(midline_equidistant, subdivisions, - contour_1mm, plot=False) - split_contours = [transform_to_acpc_standard(split_contour, - contour_1mm[:,anterior_endpoint_idx], - contour_1mm[:,posterior_endpoint_idx])[0] - for split_contour in split_contours] + areas, split_contours = subsegment_midline_orthogonal(midline_equi, subdivisions, contour_1mm, plot=False) + split_contours = [transform_to_acpc_standard(split_contour, *acpc_contour_coords)[0] + for split_contour in split_contours] elif subdivision_method == "vertical": areas, split_contours = subdivide_contour(contour_acpc, subdivisions, plot=False) elif subdivision_method == "angular": if not np.allclose(np.diff(subdivisions), np.diff(subdivisions)[0]): - logger.error('Error: Angular subdivision method (Hampel) only supports equidistant subdivision, ' - f'but got: {subdivisions}') - return None - areas, split_contours = hampel_subdivide_contour(contour_acpc, num_rays=len(subdivisions), plot=False) + logger.error("Error: Angular subdivision method (Hampel) only supports equidistant subdivision, " + f"but got: {subdivisions}. 
No measures are computed.") + return {}, contour_with_thickness, *endpoint_idxs + areas, split_contours = hampel_subdivide_contour(contour_acpc, num_rays=len(subdivisions), plot=False) elif subdivision_method == "eigenvector": pt0, pt1 = get_primary_eigenvector(contour_acpc) contour_eigen, _, _, rotate_back_eigen = transform_to_acpc_standard(contour_acpc, pt0, pt1) @@ -246,250 +466,22 @@ def process_slice( # Transform split contours back to original space split_contours = [rotate_back_acpc(split_contour) for split_contour in split_contours] - return { - 'cc_index': cc_index, - 'circularity': circularity, - 'areas': areas, - 'midline_length': midline_length, - 'thickness': thickness, - 'curvature': curvature, - 'thickness_profile': thickness_profile, - 'total_area': total_area, - 'total_perimeter': total_perimeter, - 'split_contours': split_contours, - 'midline_equidistant': midline_equidistant, - 'levelpaths': levelpaths, - 'slice_index': slice_idx - }, contour_with_thickness, anterior_endpoint_idx, posterior_endpoint_idx - - -def process_slices( - segmentation: np.ndarray, - slice_selection: str, - temp_seg_affine: np.ndarray, - midslices: np.ndarray, - ac_coords: np.ndarray, - pc_coords: np.ndarray, - num_thickness_points: int, - subdivisions: list[float], - subdivision_method: str, - contour_smoothing: float, - qc_image_path: str | None = None, - one_debug_image: bool = False, - thickness_image_path: str | None = None, - vox_size: tuple[float, float, float] | None = None, - vox2ras_tkr: np.ndarray | None = None, - save_template: str | Path | None = None, - surf_file_path: str | None = None, - overlay_file_path: str | None = None, - cc_html_path: str | None = None, - vtk_file_path: str | None = None, - verbose: bool = False -) -> tuple[list, list]: - """Process corpus callosum slices based on selection mode. - - Parameters - ---------- - segmentation : np.ndarray - 3D segmentation array. - slice_selection : str - Which slices to process ('middle', 'all', or slice number). - temp_seg_affine : np.ndarray - Base affine transformation matrix. - midslices : np.ndarray - Array of mid-sagittal slices. - ac_coords : np.ndarray - Anterior commissure coordinates. - pc_coords : np.ndarray - Posterior commissure coordinates. - num_thickness_points : int - Number of points for thickness estimation. - subdivisions : list[float] - List of fractions for anatomical subdivisions. - subdivision_method : str - Method for contour subdivision. - contour_smoothing : float - Gaussian sigma for contour smoothing. - qc_image_path : str or None, optional - Path for QC visualization image, by default None. - one_debug_image : bool, optional - Whether to save only one debug image, by default False. - thickness_image_path : str or None, optional - Path for thickness visualization image, by default None. - vox_size : tuple[float, float, float] or None, optional - Voxel size in millimeters (x, y, z), by default None. - vox2ras_tkr : np.ndarray or None, optional - Voxel to RAS tkr-space transformation matrix, by default None. - save_template : str or Path or None, optional - Directory path where to save template files, by default None. - surf_file_path : str or None, optional - Path to save surface file, by default None. - overlay_file_path : str or None, optional - Path to save overlay file, by default None. - cc_html_path : str or None, optional - Path to save HTML visualization, by default None. - vtk_file_path : str or None, optional - Path to save VTK file, by default None. 
- verbose : bool, optional - Whether to print progress information, by default False. - - Returns - ------- - list - List of slice processing results. - list - List of background IO processes. - """ - slice_results = [] - IO_processes = [] - - if slice_selection == "middle": - cc_mesh = CC_Mesh(num_slices=1) - cc_mesh.set_acpc_coords(ac_coords, pc_coords) - cc_mesh.set_resolution(vox_size[0]) - - # Process only the middle slice - slice_idx = segmentation.shape[0] // 2 - slice_affine = create_slice_affine(temp_seg_affine, slice_idx, FSAVERAGE_MIDDLE) - - (result, contour_with_thickness, - anterior_endpoint_idx, posterior_endpoint_idx) = process_slice(segmentation, - slice_idx, - ac_coords, - pc_coords, - slice_affine, - num_thickness_points, - subdivisions, - subdivision_method, - contour_smoothing, - vox_size[0]) - - cc_mesh.add_contour(0, - contour_with_thickness[0], - contour_with_thickness[1], - start_end_idx=(anterior_endpoint_idx, posterior_endpoint_idx)) - - if result is not None and qc_image_path is not None: - slice_results.append(result) - # Create visualization - if verbose: - logger.info(f"Saving segmentation qc image to {qc_image_path}") - IO_processes.append(create_visualization(subdivision_method, result, midslices, - qc_image_path, ac_coords, pc_coords, vox_size[0])) - else: - num_slices = segmentation.shape[0] - cc_mesh = CC_Mesh(num_slices=num_slices) - cc_mesh.set_acpc_coords(ac_coords, pc_coords) - cc_mesh.set_resolution(vox_size[0]) - - # Process multiple slices or specific slice - if slice_selection == "all": - start_slice = 0 - end_slice = segmentation.shape[0] - else: # specific slice number - slice_idx = int(slice_selection) - start_slice = slice_idx - end_slice = slice_idx + 1 - - for slice_idx in range(start_slice, end_slice): - if verbose: - logger.info(f"Calculating CC measurements for slice {slice_idx+1} of {end_slice-start_slice}") - - # Update affine for this slice - slice_affine = create_slice_affine(temp_seg_affine, slice_idx, FSAVERAGE_MIDDLE) - - # Process this slice - (result, contour_with_thickness, - anterior_endpoint_idx, posterior_endpoint_idx) = process_slice(segmentation, slice_idx, - ac_coords, pc_coords, - slice_affine, num_thickness_points, - subdivisions, subdivision_method, - contour_smoothing, - vox_size[0]) - - # insert - cc_mesh.add_contour(slice_idx, - contour_with_thickness[0], - contour_with_thickness[1], - start_end_idx=(anterior_endpoint_idx, posterior_endpoint_idx)) - - if result is not None: - slice_results.append(result) - - if (one_debug_image and slice_idx == num_slices // 2) or not one_debug_image: - if not one_debug_image: - qc_path_base, qc_path_ext = str(qc_image_path).rsplit('.', 1) - qc_path_with_postfix = f"{qc_path_base}_slice_{slice_idx}" - - qc_output_path_slice = Path(f"{qc_path_with_postfix}.{qc_path_ext}") - qc_output_path_slice = qc_output_path_slice.with_suffix('.png') - else: - qc_output_path_slice = qc_image_path - - if verbose: - logger.info(f"Saving segmentation qc image to {qc_output_path_slice}") - - current_slice_in_volume = midslices.shape[0] // 2 - num_slices // 2 + slice_idx - # Create visualization for this slice - IO_processes.append(create_visualization(subdivision_method, result, - midslices[current_slice_in_volume:current_slice_in_volume+1], - qc_output_path_slice, ac_coords, pc_coords, - vox_size[0], f' (Slice {slice_idx})')) - - if save_template is not None: - # Convert to Path object and ensure directory exists - template_dir = Path(save_template) - template_dir.mkdir(parents=True, exist_ok=True) - 
if verbose: - logger.info("Saving template files (contours.txt, thickness_values.txt, " - f"thickness_measurement_points.txt) to {template_dir}") - cc_mesh.save_contours(str(template_dir / 'contours.txt')) - cc_mesh.save_thickness_values(str(template_dir / 'thickness_values.txt')) - cc_mesh.save_thickness_measurement_points(str(template_dir / 'thickness_measurement_points.txt')) - - - if len(cc_mesh.contours) > 1 and thickness_image_path is not None: - cc_mesh.fill_thickness_values() - cc_mesh.create_mesh() - cc_mesh.smooth_(1) - if verbose: - logger.info(f"Saving CC 3D visualization to {cc_html_path}") - cc_mesh.plot_mesh(output_path=str(cc_html_path), show_mesh_edges=True) - - if vtk_file_path is not None: - if verbose: - logger.info(f"Saving vtk file to {vtk_file_path}") - cc_mesh.write_vtk(str(vtk_file_path)) - - - cc_mesh.to_fs_coordinates(vox2ras_tkr=vox2ras_tkr, vox_size=vox_size) - - if overlay_file_path is not None: - if verbose: - logger.info(f"Saving overlay file to {overlay_file_path}") - cc_mesh.write_overlay(str(overlay_file_path)) - - if surf_file_path is not None: - if verbose: - logger.info(f"Saving surf file to {surf_file_path}") - cc_mesh.write_fssurf(str(surf_file_path)) - - - - if thickness_image_path is not None: - if verbose: - logger.info(f"Saving thickness image to {thickness_image_path}") - with HiddenPrints(): - cc_mesh.snap_cc_picture(str(thickness_image_path)) - - - if not slice_results: - logger.error("Error: No valid slices were found for postprocessing") - raise ValueError("No valid slices were found for postprocessing") - - return slice_results, IO_processes - - + measures = { + "cc_index": cc_index, + "circularity": circularity, + "areas": areas, + "midline_length": midline_len, + "thickness": thickness, + "curvature": curvature, + "thickness_profile": thickness_profile, + "total_area": total_area, + "total_perimeter": total_perimeter, + "split_contours": split_contours, + "midline_equidistant": midline_equi, + "levelpaths": levelpaths, + "slice_index": slice_idx + } + return measures, contour_with_thickness, *endpoint_idxs def vectorized_line_test(coords_x: np.ndarray, coords_y: np.ndarray, @@ -622,7 +614,7 @@ def make_subdivision_mask( for s in subdivision_segments: if len(s) != 2: - logger.error(f'Subdivision segment {s} has {len(s)} points, expected 2') + logger.error(f"Subdivision segment {s} has {len(s)} points, expected 2") # Create coordinate grids for all points in the slice rows, cols = slice_shape @@ -648,25 +640,23 @@ def make_subdivision_mask( # Debug visualization (optional) # import matplotlib.pyplot as plt # fig, ax = plt.subplots(figsize=(10, 8)) - # ax.imshow(subdivision_mask, cmap='tab10') - # ax.plot([line_start[0], line_end[0]], [line_start[1], line_end[1]], 'r-', linewidth=2) - # ax.set_title(f'After subdivision line {segment_idx}') + # ax.imshow(subdivision_mask, cmap="tab10") + # ax.plot([line_start[0], line_end[0]], [line_start[1], line_end[1]], "r-", linewidth=2) + # ax.set_title(f"After subdivision line {segment_idx}") # plt.show() return subdivision_mask -def check_area_changes(contours: list[np.ndarray], threshold: float = 0.3, verbose: bool = False) -> bool: +def check_area_changes(contours: list[np.ndarray], threshold: float = 0.3) -> bool: """Check for large changes between consecutive CC areas and issue warnings. Parameters ---------- contours : list[np.ndarray] List of contours. - threshold : float, optional - Threshold for relative change, by default 0.3 (30%). 
- verbose : bool, optional - Whether to print warnings, by default False. + threshold : float, default=0.3 + Threshold for relative change. Returns ------- @@ -674,29 +664,25 @@ def check_area_changes(contours: list[np.ndarray], threshold: float = 0.3, verbo True if no large area changes are detected, False otherwise. """ - areas = [np.sum(np.sqrt(np.sum((np.diff(contour, axis=0))**2, axis=1))) for contour in contours] + areas = np.asarray([np.sum(np.sqrt(np.sum((np.diff(contour, axis=0))**2, axis=1))) for contour in contours]) assert len(areas) > 1, "At least two areas are required to check for area changes" - - for i in range(len(areas) - 1): - if areas[i] == 0 and areas[i+1] == 0: - continue # Skip if both areas are zero - - if areas[i] == 0 or areas[i+1] == 0: - # One area is zero, the other is not - this is a 100% change - if verbose: - logger.warning(f"Large area change detected: area {i+1} = {areas[i]:.2f} mm², " - f"area {i+2} = {areas[i+1]:.2f} mm² (one area is zero)") - return False - - # Calculate relative change - relative_change = abs(areas[i+1] - areas[i]) / areas[i] - - if relative_change > threshold: - percent_change = relative_change * 100 - if verbose: - logger.warning(f"Large corpus callosum area change between slices detected: " - f"area {i+1} = {areas[i]:.2f} mm², " - f"area {i+2} = {areas[i+1]:.2f} mm² ({percent_change:.1f}% change)") - return False + + if np.any(areas == 0): + # One area is zero, the other is not - this is a 100% change + logger.warning(f"Areas {np.where(areas == 0)[0].tolist()} are zero mm²") + return False + + # Calculate relative change + relative_change = np.abs(np.diff(areas)) / areas[:-1] + + if np.any(where_change := relative_change > threshold): + indices = np.where(where_change)[0] + percent_change = relative_change[where_change] * 100 + logger.info( + f"Large corpus callosum area change after slices {indices.tolist()} detected: " + + ", ".join(f"areas {(i,i+1)} = ({areas[i]:.2f},{areas[i+1]:.2f}) mm² ({p:.1f}% change)" + for i, p in zip(indices, percent_change, strict=True)) + ) + return False return True \ No newline at end of file diff --git a/CorpusCallosum/shape/cc_subsegment_contour.py b/CorpusCallosum/shape/cc_subsegment_contour.py index ced67946..16ab7dc2 100644 --- a/CorpusCallosum/shape/cc_subsegment_contour.py +++ b/CorpusCallosum/shape/cc_subsegment_contour.py @@ -13,11 +13,15 @@ # limitations under the License. from collections.abc import Callable +from typing import TypeVar import matplotlib.pyplot as plt import numpy as np +from numpy import typing as npt from scipy.spatial import ConvexHull +_TS = TypeVar("_TS", bound=np.number) + def minimum_bounding_rectangle(points): """Find the smallest bounding rectangle for a set of points. @@ -80,7 +84,7 @@ def minimum_bounding_rectangle(points): return rval -def get_area_from_subsegments(split_contours): +def calc_subsegment_area(split_contours: list[npt.NDArray[_TS]]) -> npt.NDArray[_TS]: """Calculate area of each subsegment using the shoelace formula. Parameters @@ -90,18 +94,14 @@ def get_area_from_subsegments(split_contours): Returns ------- - np.ndarray + subsegment_areas : np.ndarray Array containing the area of each subsegment. 
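For reference on the area computation documented here: each split contour's enclosed area is evaluated below as the absolute line integral |∮ y dx| via `np.trapz`, which for a closed polygon is exactly the shoelace formula, and `np.ediff1d` then converts the nested totals into per-subsegment areas. A self-contained sketch of the classic shoelace formula (a generic helper, not code from this module):

```python
import numpy as np

def shoelace_area(contour: np.ndarray) -> float:
    """Enclosed area of an ordered, closed 2D contour of shape (N, 2).
    The absolute value makes the result independent of orientation (CW/CCW)."""
    x, y = contour[:, 0], contour[:, 1]
    return 0.5 * abs(np.dot(x, np.roll(y, -1)) - np.dot(y, np.roll(x, -1)))

# sanity check: a unit square has area 1
square = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=float)
assert np.isclose(shoelace_area(square), 1.0)
```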
""" # calculate area of each split contour using the shoelace formula - areas = [np.abs(np.trapz(split_contour[1], split_contour[0])) for split_contour in split_contours] - area_out = np.zeros(len(areas)) - for i in range(len(areas)): - if i == len(areas) - 1: - area_out[i] = areas[i] - else: - area_out[i] = areas[i] - areas[i + 1] - return area_out + areas = np.abs([np.trapz(split_contour[1], split_contour[0]) for split_contour in split_contours]) + if len(areas) == 1: + return np.asarray(areas[0]) + return np.ediff1d(np.asarray(areas)[::-1], to_end=areas[-1]) def subsegment_midline_orthogonal(midline, area_weights, contour, plot=True, ax=None, extremes=None): @@ -124,13 +124,10 @@ def subsegment_midline_orthogonal(midline, area_weights, contour, plot=True, ax= Returns ------- + subsegment_area : list[np.ndarray] split_contours : list[np.ndarray] List of contour arrays for each subsegment. - split_points : np.ndarray - Array of split points. - edge_directions : np.ndarray - Array of edge directions at split points. - + """ # get points after midline length of splits @@ -319,7 +316,7 @@ def subsegment_midline_orthogonal(midline, area_weights, contour, plot=True, ax= ax.axis("equal") plt.show() - return get_area_from_subsegments(split_contours), split_contours + return calc_subsegment_area(split_contours), split_contours def hampel_subdivide_contour(contour, num_rays, plot=False, ax=None): @@ -458,7 +455,7 @@ def hampel_subdivide_contour(contour, num_rays, plot=False, ax=None): ax.axis("equal") plt.show() - return get_area_from_subsegments(split_contours), split_contours + return calc_subsegment_area(split_contours), split_contours def subdivide_contour( @@ -800,7 +797,7 @@ def subdivide_contour( ax.axis("equal") plt.show() - return get_area_from_subsegments(split_contours), split_contours + return calc_subsegment_area(split_contours), split_contours def transform_to_acpc_standard(contour_ras, ac_pt_ras, pc_pt_ras): diff --git a/CorpusCallosum/shape/cc_thickness.py b/CorpusCallosum/shape/cc_thickness.py index bb6bf7ba..a8b0d6e7 100644 --- a/CorpusCallosum/shape/cc_thickness.py +++ b/CorpusCallosum/shape/cc_thickness.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Literal, overload import numpy as np import scipy.interpolate @@ -44,11 +45,19 @@ def compute_curvature(path: np.ndarray) -> np.ndarray: return angle_diffs +@overload +def convert_to_ras(contour: np.ndarray, vox2ras_matrix: np.ndarray, get_parameters: Literal[False] = False) \ + -> np.ndarray: ... + +@overload +def convert_to_ras(contour: np.ndarray, vox2ras_matrix: np.ndarray, get_parameters: Literal[True]) \ + -> tuple[np.ndarray, bool, bool, bool]: ... + def convert_to_ras( contour: np.ndarray, vox2ras_matrix: np.ndarray, - get_parameters: bool = False -) -> np.ndarray | tuple[np.ndarray, bool, bool, bool]: + return_parameters: bool = False +): """Convert contour coordinates from voxel space to RAS space. Parameters @@ -57,19 +66,19 @@ def convert_to_ras( Array of shape (2, N) or (3, N) containing contour coordinates. vox2ras_matrix : np.ndarray 4x4 voxel to RAS transformation matrix. - get_parameters : bool, optional - If True, return additional transformation parameters, by default False. + return_parameters : bool, default=False + If True, return additional transformation parameters (see below). 
    Returns
    -------
-    np.ndarray | tuple[np.ndarray, bool, bool, bool]
-        If get_parameters is False:
+    contour : np.ndarray
         Transformed contour coordinates.
-        If get_parameters is True:
-        - anterior_reversed : bool
-            Whether anterior axis was reversed.
-        - superior_reversed : bool
-            Whether superior axis was reversed.
-        - swap_axes : bool
-            Whether axes were swapped.
-
+    anterior_reversed : bool
+        Only if return_parameters is True, whether anterior axis was reversed.
+    superior_reversed : bool
+        Only if return_parameters is True, whether superior axis was reversed.
+    swap_axes : bool
+        Only if return_parameters is True, whether axes were swapped.
     """
     # converting to AS (no left-right dimension), out of plane movement is ignores,
     # so we only do scaling, axes swapping and flipping - no rotation
@@ -89,8 +98,8 @@ def convert_to_ras(
         contour = contour[[1, 0]]
 
     # determine if axis were reversed
-    superior_reversed = (axis_swaps[2, :] == -1).any()
-    anterior_reversed = (axis_swaps[1, :] == -1).any()
+    superior_reversed = np.any(axis_swaps[2, :] == -1)
+    anterior_reversed = np.any(axis_swaps[1, :] == -1)
 
     # flip axes if necessary
     if superior_reversed:
@@ -104,18 +113,20 @@
     # voxel * vox_size = mm
     contour = (contour.T * scaling[1:]).T
 
-    if get_parameters:
+    if return_parameters:
         return contour, anterior_reversed, superior_reversed, swap_axes
     else:
         return contour
 
-    # # Add a third dimension (z) with 0 and a fourth dimension (homogeneous coordinate) with 1
+    # Add a third dimension (z) with 0 and a fourth dimension (homogeneous coordinate) with 1
     elif contour.shape[0] == 3:
         contour_homogeneous = np.vstack([contour, np.ones(contour.shape[1])])
         # Apply the transformation
         contour = (vox2ras_matrix @ contour_homogeneous)[:3, :]
 
         return contour
+    else:
+        raise ValueError("Invalid shape of contour")
 
 
 def set_contour_zero_idx(contour, idx, anterior_endpoint_idx, posterior_endpoint_idx):
@@ -134,14 +145,13 @@ def set_contour_zero_idx(contour, idx, anterior_endpoint_idx, posterior_endpoint
 
     Returns
     -------
-    tuple
-        - contour : np.ndarray
-            Rolled contour points.
-        - anterior_endpoint_idx : int
-            Updated anterior endpoint index.
-        - posterior_endpoint_idx : int
-            Updated posterior endpoint index.
-    """
+    contour : np.ndarray
+        Rolled contour points.
+    anterior_endpoint_idx : int
+        Updated anterior endpoint index.
+    posterior_endpoint_idx : int
+        Updated posterior endpoint index.
+    """
     contour = np.roll(contour, -idx, axis=0)
     anterior_endpoint_idx = (anterior_endpoint_idx - idx) % contour.shape[0]
     posterior_endpoint_idx = (posterior_endpoint_idx - idx) % contour.shape[0]
@@ -316,9 +326,10 @@ def cc_thickness(
 
     Returns
    -------
-    tuple[np.ndarray, np.ndarray]
-        - thickness_values : Array of thickness measurements.
-        - measurement_points : Array of points where thickness was measured.
+    thickness_values : np.ndarray
+        Array of thickness measurements.
+    measurement_points : np.ndarray
+        Array of points where thickness was measured.
 
     Notes
     -----
diff --git a/CorpusCallosum/transforms/localization_transforms.py b/CorpusCallosum/transforms/localization_transforms.py
index f7fa3f0f..12e1e5ac 100644
--- a/CorpusCallosum/transforms/localization_transforms.py
+++ b/CorpusCallosum/transforms/localization_transforms.py
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
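A side note on the `@overload` pair introduced for `convert_to_ras` above: `typing.Literal` lets type checkers narrow the return type from the value of the flag, so callers see either `np.ndarray` or the tuple instead of a union of both. Note that for keyword calls the overloads' parameter name has to match the implementation (the hunk still says `get_parameters` in the overloads but `return_parameters` in the signature). A minimal standalone sketch of the pattern, with generic names:

```python
from typing import Literal, overload

import numpy as np

@overload
def flip(pts: np.ndarray, with_info: Literal[False] = False) -> np.ndarray: ...
@overload
def flip(pts: np.ndarray, with_info: Literal[True]) -> tuple[np.ndarray, bool]: ...

def flip(pts: np.ndarray, with_info: bool = False):
    # type checkers pick the overload matching the literal flag
    reversed_pts = pts[::-1]
    return (reversed_pts, True) if with_info else reversed_pts

arr = flip(np.zeros((2, 3)))                      # inferred as np.ndarray
arr2, was_flipped = flip(np.zeros((2, 3)), True)  # inferred as tuple[np.ndarray, bool]
```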
+from logging import getLogger + import numpy as np +import torch from monai.transforms import MapTransform, RandomizableTransform @@ -30,8 +33,8 @@ class CropAroundACPCFixedSize(RandomizableTransform, MapTransform): Fixed size of the crop window (width, height) allow_missing_keys : bool, optional Whether to allow missing keys in the data dictionary, by default False - random_translate : float, optional - Maximum random translation in voxels, by default 0 + random_translate : int, default=0 + Maximum random translation in voxels. Notes ----- @@ -49,30 +52,35 @@ class CropAroundACPCFixedSize(RandomizableTransform, MapTransform): If the crop boundaries extend outside the image dimensions """ - def __init__(self, keys: list[str], fixed_size: tuple[int, int], - allow_missing_keys: bool = False, random_translate: float = 0) -> None: + def __init__( + self, + keys: list[str], + fixed_size: tuple[int, int], + allow_missing_keys: bool = False, + random_translate: int = 0, + ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self) self.random_translate = random_translate self.fixed_size = fixed_size def __call__(self, data: dict) -> dict: - """Apply the transform to the data. + """Apply the 2D crop transform to the data. Parameters ---------- data : dict - Dictionary containing the data to transform + Dictionary containing the data to transform AND keys AC_center and PC_center, each of shape (B, 2). Returns ------- dict Transformed data dictionary with cropped images and updated coordinates. Also includes crop boundary information: - - crop_left : int - - crop_right : int - - crop_top : int - - crop_bottom : int + - crop_left : list[int] + - crop_right : list[int] + - crop_top : list[int] + - crop_bottom : list[int] Raises ------ @@ -81,54 +89,63 @@ def __call__(self, data: dict) -> dict: """ d = dict(data) - for key in self.keys: - if key not in d.keys() and self.allow_missing_keys: - continue + expected_keys = {"PC_center", "AC_center"} | set(self.keys) if not self.allow_missing_keys else {} + + if expected_keys & set(d.keys()) != expected_keys: + raise ValueError(f"The following keys are missing in the data dictionary: {expected_keys - set(d.keys())}!") + + if any(d[k].ndim != 2 or d[k].shape[1] != 2 for k in ["PC_center", "AC_center"]): + raise ValueError("Shape of AC_center or PC_center incorrect, must be (B, 2)!") + + if any(d[k].ndim != 4 for k in self.keys if k in d.keys()): + raise ValueError(f"At least one key of {self.keys} does not have a 4-dimensional tensor.") - # Get AC and PC centers - pc_center = d['PC_center'] - ac_center = d['AC_center'] - # calculate center point between AC and PC - center_point = ((ac_center + pc_center) / 2).astype(int) + center_point = ((d['AC_center'] + d['PC_center']) / 2).astype(int) # Calculate voxel padding based on mm padding voxel_padding = np.asarray(self.fixed_size) // 2 + existing_keys = set(self.keys) & set(d.keys()) + if len(existing_keys) == 0: + getLogger(__name__).warning(f"None of the keys in {self.keys} are present in the data dictionary!") + return d + + first_key = tuple(existing_keys)[0] + + # Calculate crop boundaries with padding and random translation + crops = center_point - voxel_padding + # Add random translation if specified if self.random_translate > 0: - random_translate = np.random.randint(-self.random_translate, - self.random_translate, size=2) - else: - random_translate = np.asarray((0, 0)) + crops += np.random.randint( + -self.random_translate, + self.random_translate + 1, + 
size=(d[first_key].shape[0], 2), + ) - # Calculate crop boundaries with padding and random translation - crops = center_point - voxel_padding + random_translate - # Ensure crop boundaries are within image - img_shape = np.asarray(d['image'].shape[2:]) # Get spatial dimensions - crops = np.maximum(0, np.minimum(img_shape, crops + np.asarray(self.fixed_size)) - np.asarray(self.fixed_size)) - crop_left, crop_top = crops.tolist() - crop_right, crop_bottom = (crops + np.asarray(self.fixed_size)).tolist() + img_shape = np.asarray(d[first_key].shape[2:]) # Get spatial dimensions + if any(np.any(img_shape != d[k].shape[2:]) for k in self.keys if k in d.keys()): + raise ValueError(f"At least one key of {self.keys} does not have the expected shape.") + + patch_size_with_batch_dim = np.asarray(self.fixed_size)[None] + crops = np.maximum(0, np.minimum(img_shape, crops + patch_size_with_batch_dim) - patch_size_with_batch_dim) + d["crop_left"], d["crop_top"] = crops.T.tolist() + d["crop_right"], d["crop_bottom"] = (crops_end := crops + patch_size_with_batch_dim).T.tolist() # raise error if crop boundaries are out of image - if crop_left < 0 or crop_right > d['image'].shape[2] or crop_top < 0 or crop_bottom > d['image'].shape[3]: + if np.any(crops < 0) or np.any(crops_end > np.asarray([d[first_key].shape[2:]])): raise ValueError("Crop boundaries are out of image") # Apply crop to image for key in self.keys: if key not in d.keys() and self.allow_missing_keys: continue - - d[key] = d[key][:, :, crop_left:crop_right, crop_top:crop_bottom] - - # Update point coordinates relative to cropped image - d['PC_center'][1:] = d['PC_center'][1:] - [crop_left, crop_top] - d['AC_center'][1:] = d['AC_center'][1:] - [crop_left, crop_top] - - - d['crop_left'] = crop_left - d['crop_right'] = crop_right - d['crop_top'] = crop_top - d['crop_bottom'] = crop_bottom + arr = [v[:, cl:cr, ct:cb] for v, cl, ct, cr, cb in zip(d[key], *crops.T, *crops_end.T, strict=True)] + d[key] = torch.stack(arr, dim=0) if torch.is_tensor(arr[0]) else np.stack(arr, axis=0) + + # Update point coordinates relative to cropped image + d["PC_center"] = d["PC_center"] - crops + d["AC_center"] = d["AC_center"] - crops return d diff --git a/recon_surf/align_points.py b/recon_surf/align_points.py index e2ed0e5b..c69446ee 100755 --- a/recon_surf/align_points.py +++ b/recon_surf/align_points.py @@ -127,7 +127,7 @@ def find_rotation(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: return R -def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray, verbose: bool = False) -> np.ndarray: +def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray, verbose: bool = False) -> npt.NDArray[float]: """ Find rigid transformation matrix between two point sets. @@ -142,7 +142,7 @@ def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray, verbose: bool = False) -> Returns ------- - T + np.ndarray Homogeneous transformation matrix. 
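The boundary handling in `CropAroundACPCFixedSize.__call__` above packs "keep the crop window inside the image" into one vectorized expression. The same clamping idea as a standalone sketch — for a patch no larger than the image it is equivalent to `np.clip(start, 0, img - patch)`:

```python
import numpy as np

def clamp_crop_start(start: np.ndarray, patch: np.ndarray, img: np.ndarray) -> np.ndarray:
    # shift the window so it ends inside the image, then shift it back so it starts at >= 0
    return np.maximum(0, np.minimum(img, start + patch) - patch)

start = np.array([-3, 250])   # desired top-left corner, possibly out of bounds
patch = np.array([64, 64])    # crop window size
img = np.array([256, 256])    # image size
print(clamp_crop_start(start, patch, img))  # -> [  0 192]
```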
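For the homogeneous matrix that `find_rigid` assembles below: an m x m rotation `R` and a length-m translation `t` are packed into a single (m+1) x (m+1) matrix so the rigid map becomes one matrix product. A minimal sketch of the packing and its effect on a homogeneous point:

```python
import numpy as np

def to_homogeneous(R: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Build T such that T @ [p, 1] == [R @ p + t, 1] for any point p."""
    m = R.shape[0]
    T = np.identity(m + 1)
    T[:m, :m] = R
    T[:m, m] = t
    return T

R = np.array([[0.0, -1.0], [1.0, 0.0]])  # 90-degree rotation in 2D
t = np.array([2.0, 0.0])
p = np.array([1.0, 0.0, 1.0])            # point (1, 0) in homogeneous form
print(to_homogeneous(R, t) @ p)          # -> [2. 1. 1.]
```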
""" if p_mov.shape != p_dst.shape: @@ -160,9 +160,9 @@ def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray, verbose: bool = False) -> t = centroid_dst.T - np.dot(R, centroid_mov.T) # homogeneous transformation m = p_mov.shape[1] - T = np.identity(m + 1) - T[:m, :m] = R - T[:m, m] = t + rigid_transform = np.identity(m + 1, dtype=float) + rigid_transform[:m, :m] = R + rigid_transform[:m, m] = t # compute disteances if verbose: dd = p_mov - p_dst @@ -170,7 +170,7 @@ def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray, verbose: bool = False) -> dd = (np.transpose(R @ np.transpose(p_mov)) + t) - p_dst print(f"Final avg SSD: {np.sum(dd * dd) / p_mov.shape[0]}") # return T, R, t - return T + return rigid_transform def find_affine(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: """ From f2aa773ffbf459dc5539eb334eae99c7d77b19c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Tue, 25 Nov 2025 14:37:50 +0100 Subject: [PATCH 29/68] rename files and standardize file names change the default settings of various output files to lighter default settings while keeping the ability to have those outputs, e.g. for qc output. rename, standardize and improve "meaningfulness" of variable names Update typing and docstrings Remove unwanted extra empty lines and debugging plots (comments) Use suppress_stdout from FastSurfer (which uses redirect_stdout) instead of the new implementation HidePrints add and optimize parallelization across multiple files (halfs processing time of fastsurfer-cc.py) move some constant values from various files into constants or just use the constants that were already available fix several functions broken by myself in review or other commits make plotly html plots use cdn-delivered javascript to significantly reduce he size of files Remove redundant function parts in recon_cc_surf_measures_multi Reorder arguments in recon_cc_surf_measure Integrate FastSurferCC into run_fastsurfer, create stats files of updated aseg and asegdkt Update on-demand execution of FastSurferCC in recon-surf Add additional option to also reduce asegdkt to aseg in paint_cc_into_pred --- CorpusCallosum/cc_visualization.py | 2 +- CorpusCallosum/data/constants.py | 34 +- CorpusCallosum/data/read_write.py | 27 +- CorpusCallosum/fastsurfer_cc.py | 583 ++++++++++-------- CorpusCallosum/paint_cc_into_pred.py | 224 ++++--- .../registration/mapping_helpers.py | 77 +-- .../segmentation/segmentation_inference.py | 4 +- CorpusCallosum/shape/cc_mesh.py | 13 +- CorpusCallosum/shape/cc_postprocessing.py | 302 ++++----- CorpusCallosum/shape/cc_thickness.py | 198 +----- CorpusCallosum/utils/utils.py | 60 -- CorpusCallosum/visualization/visualization.py | 49 +- recon_surf/recon-surf.sh | 81 +-- run_fastsurfer.sh | 110 +++- 14 files changed, 870 insertions(+), 894 deletions(-) delete mode 100644 CorpusCallosum/utils/utils.py diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index c1df96f8..8d9aa856 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -168,7 +168,7 @@ def main( measurement_points_path=options.measurement_points, output_dir=options.output_dir, resolution=options.resolution, - smooth_iterations=options.smooth_iterations, + smoothing_window=options.smoothing_window, colormap=options.colormap, color_range=options.color_range, legend=options.legend, diff --git a/CorpusCallosum/data/constants.py b/CorpusCallosum/data/constants.py index 321bd8da..8087079b 100644 --- a/CorpusCallosum/data/constants.py +++ 
b/CorpusCallosum/data/constants.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - - from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT ### Constants @@ -34,25 +32,25 @@ STANDARD_OUTPUT_PATHS = { ## images - "upright_volume": None, # orig.mgz mapped to upright space + "upright_volume": None, # orig.mgz mapped to upright space ## segmentations - "segmentation": "mri/callosum_seg_upright.mgz", # corpus callosum segmentation in upright space - "orig_space_segmentation": "mri/callosum_seg_aseg_space.mgz", # cc segmentation in input segmentations space - "softlabels_cc": "mri/callosum_seg_soft.mgz", # cc softlabels in upright space - "softlabels_fn": "mri/fornix_seg_soft.mgz", # fornix softlabels in upright space - "softlabels_background": "mri/background_seg_soft.mgz", # background softlabels in upright space + "segmentation": "mri/callosum.CC.upright.mgz", # corpus callosum segmentation in upright space + "segmentation_in_orig": "mri/callosum.CC.orig.mgz", # cc segmentation in input segmentations space + "softlabels_cc": "mri/callosum.CC.soft.mgz", # cc softlabels in upright space + "softlabels_fn": "mri/fornix.CC.soft.mgz", # fornix softlabels in upright space + "softlabels_background": "mri/background.CC.soft.mgz", # background softlabels in upright space ## stats - "cc_markers": "stats/callosum.CC.midslice.json", # cc metrics for middle slice - "postproc_results": "stats/callosum.CC.all_slices.json", # cc metrics for all slices + "cc_markers": "stats/callosum.CC.midslice.json", # cc metrics for middle slice + "cc_measures": "stats/callosum.CC.all_slices.json", # cc metrics for all slices ## transforms - "upright_lta": "mri/transforms/cc_up.lta", # lta transform from orig to upright space - "orient_volume_lta": "mri/transforms/orient_volume.lta", # lta transform from orig to upright+acpc corrected space + "upright_lta": "mri/transforms/cc_up.lta", # lta transform from orig to upright space + "orient_volume_lta": "mri/transforms/orient_volume.lta", # lta transform from orig to upright+acpc corrected space ## qc - "qc_image": "qc_snapshots/callosum.png", # debug image of cc contours - "thickness_image": "qc_snapshots/callosum_thickness.png", # whippersnappy 3D image of cc thickness - "cc_html": "qc_snapshots/corpus_callosum.html", # plotly cc visualization + "qc_image": "{qc_output_dir}/callosum.png", # debug image of cc contours + "thickness_image": "{qc_output_dir}/callosum.thickness.png", # whippersnappy 3D image of cc thickness + "cc_html": "{qc_output_dir}/corpus_callosum.html", # plotly cc visualization ## surface - "surf_file": "surf/callosum.surf", # cc surface file - "overlay_file": "surf/callosum.thickness.w", # cc surface overlay file - "vtk_file": "surf/callosum_mesh.vtk", # vtk file of cc mesh + "cc_surf": "surf/callosum.surf", # cc surface file + "cc_thickness_overlay": "surf/callosum.thickness.w", # cc surface overlay file + "cc_surf_vtk": "surf/callosum.vtk", # vtk file of cc mesh } \ No newline at end of file diff --git a/CorpusCallosum/data/read_write.py b/CorpusCallosum/data/read_write.py index 01b354d7..503b5b12 100644 --- a/CorpusCallosum/data/read_write.py +++ b/CorpusCallosum/data/read_write.py @@ -21,6 +21,7 @@ from numpy import typing as npt import FastSurferCNN.utils.logging as logging +from FastSurferCNN.utils.parallel import thread_executor class FSAverageHeader(TypedDict): @@ -59,23 +60,21 @@ def get_centroids_from_nib(seg_img: nib.analyze.SpatialImage, label_ids: list[in else: 
labels = label_ids - def _calc_ras_centroid(mask_vox: npt.NDArray[np.integer]) -> npt.NDArray[float]: - # Calculate centroid in voxel space - vox_centroid = np.mean(mask_vox, axis=1, dtype=float) + def _each_label(label): + # Get voxel indices for this label + if np.any(mask := seg_data == label): + # Calculate centroid in voxel space + vox_centroid = np.mean(np.where(mask), axis=1, dtype=float) - # Convert to homogeneous coordinates - vox_centroid = np.append(vox_centroid, 1) + # Convert to homogeneous coordinates + vox_centroid_hom = np.append(vox_centroid, 1) - # Transform to RAS coordinates and return without homogeneous coordinate - return (vox2ras @ vox_centroid)[:3] + # Transform to RAS coordinates and return without homogeneous coordinate + return int(label), (vox2ras @ vox_centroid_hom)[:3] + else: + return int(label), None - centroids = {} - for label in labels: - # Get voxel indices for this label - vox_coords = np.array(np.where(seg_data == label)) - centroids[int(label)] = None if vox_coords.size == 0 else _calc_ras_centroid(vox_coords) - - return centroids + return dict(thread_executor().map(_each_label, labels)) def convert_numpy_to_json_serializable(obj: object) -> object: diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index de458eea..eccdc1b6 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -15,8 +15,10 @@ import argparse import json +from collections.abc import Iterable from pathlib import Path -from typing import Literal, cast +from time import perf_counter_ns +from typing import Literal, TypeVar, cast import nibabel as nib import numpy as np @@ -49,31 +51,117 @@ ) from CorpusCallosum.segmentation import segmentation_inference, segmentation_postprocessing from CorpusCallosum.shape.cc_postprocessing import ( + SliceSelection, SubdivisionMethod, check_area_changes, make_subdivision_mask, recon_cc_surf_measures_multi, ) from FastSurferCNN.data_loader.conform import is_conform +from FastSurferCNN.segstats import HelpFormatter from FastSurferCNN.utils import logging +from FastSurferCNN.utils.arg_types import path_or_none from FastSurferCNN.utils.common import SubjectDirectory, find_device -from FastSurferCNN.utils.common import thread_executor as executor -from recon_surf import lta +from FastSurferCNN.utils.parallel import shutdown_executors, thread_executor +from FastSurferCNN.utils.parser_defaults import modify_argument from recon_surf.align_points import find_rigid +from recon_surf.lta import write_lta logger = logging.get_logger(__name__) -SliceSelection = Literal["middle", "all"] | int +_TPathLike = TypeVar("_TPathLike", str, Path, Literal[None]) + + + +class ReplaceQCOutputDir(Path): + """ + A helper class to validate `qc_output_dir` dependent paths. + + Replaces {qc_output_dir} at the start of filename with the correct qc_output_dir. + Also returns None, if qc_output_dir was None. + """ + + def __init__(self, a: Path | str | None): + if a is None: + a = "{None}" + if "{qc_output_dir}" in str(a).removeprefix("{qc_output_dir}/"): + raise ValueError("If the argument contains {qc_output_dir}, it must start with '{qc_output_dir}/'!") + super().__init__(a) + + def replace_qc_dir(self, qc_output_dir: _TPathLike) -> Path | None: + """ + Helper function to replace {qc_output_dir} at the start of filename with the correct qc_output_dir. + + Also returns None, if qc_output_dir was None. 
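Stepping back to the `get_centroids_from_nib` rework earlier in this patch: each label's voxel centroid is computed, lifted to homogeneous coordinates, and mapped through vox2ras, with the per-label work fanned out over a thread pool. The coordinate math as a standalone sketch, using a plain `ThreadPoolExecutor` in place of FastSurfer's `thread_executor` helper:

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np

def label_centroid_ras(seg: np.ndarray, vox2ras: np.ndarray, label: int):
    """Return (label, RAS centroid) or (label, None) if the label is absent."""
    mask = seg == label
    if not mask.any():
        return label, None
    vox_centroid = np.mean(np.nonzero(mask), axis=1)  # mean (i, j, k) index
    hom = np.append(vox_centroid, 1.0)                # homogeneous coordinates
    return label, (vox2ras @ hom)[:3]                 # drop the trailing 1

seg = np.zeros((4, 4, 4), dtype=int)
seg[1:3, 1:3, 1:3] = 17
with ThreadPoolExecutor() as pool:
    centroids = dict(pool.map(lambda lb: label_centroid_ras(seg, np.eye(4), lb), [17, 99]))
print(centroids)  # {17: array([1.5, 1.5, 1.5]), 99: None}
```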
+
+        Notes
+        -----
+        This function implements the substitution of the {qc_output_dir} placeholder described in the class
+        docstring.
+        """
+        if str(self) == "{None}":
+            return None
+        elif "{qc_output_dir}" not in str(self):
+            return self
+        elif qc_output_dir is None:
+            return None
+
+        return Path(str(self).replace("{qc_output_dir}", str(qc_output_dir)))
+
+
+class ArgumentDefaultsHelpFormatter(HelpFormatter):
+    """Help message formatter which adds default values to argument help."""
+
+    def _get_help_string(self, action):
+        """
+        Add the default value to the option help message.
+        """
+        help = action.help
+        if help is None:
+            help = ''
+
+        if "%(default)" not in help and not getattr(action, "required", False):
+            if action.default is not argparse.SUPPRESS and not getattr(action.default, "DO_NOT_PRINT_DEFAULT", False):
+                defaulting_nargs = [argparse.OPTIONAL, argparse.ZERO_OR_MORE]
+                if action.option_strings or action.nargs in defaulting_nargs:
+                    help += " (not used by default)" if action.default is None else " (default: %(default)s)"
+        return help
+
+
+class _FixFloatFormattingList(list):
+    def __init__(self, items: Iterable, item_format_spec: str):
+        self._format_spec = item_format_spec
+        super().__init__(items)
+
+    def __str__(self):
+        return "[" + ", ".join(map(lambda x: format(x, self._format_spec), self)) + "]"
+
+
+def _do_not_print(value):
+    class _DoNotPrintGeneric(type(value)):
+        DO_NOT_PRINT_DEFAULT = True
+
+    return _DoNotPrintGeneric(value)
 
 
 def make_parser() -> argparse.ArgumentParser:
     """Create the argument parse object for the pipeline."""
     from FastSurferCNN.utils.parser_defaults import add_arguments
 
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        action="count",
+        default=_do_not_print(0),
+        help="Enable verbose (pass twice for debug-output).",
+    )
 
     # Specify subject directory + subject ID, OR specify individual MRI and segmentation files + output paths
     add_arguments(parser, ["sd", "sid", "conformed_name", "aseg_name", "device"])
 
+    def _set_help_sid(action):
+        action.help = "The subject id to use."
+    modify_argument(parser, "--sid", _set_help_sid)
+
     parser.add_argument(
         "--num_thickness_points",
         type=int,
@@ -85,25 +173,26 @@ def make_parser() -> argparse.ArgumentParser:
         type=float,
         metavar="FRAC",
         nargs=4,
-        default=[1/6, 1/2, 2/3, 3/4],
+        default=_FixFloatFormattingList([1 / 6, 1 / 2, 2 / 3, 3 / 4], ".3f"),
         help="List of FOUR subdivision fractions for the corpus callosum subsegmentation.",
     )
     parser.add_argument(
         "--subdivision_method",
-        default="shape",
-        help="Method for contour subdivision. \
-            Options: shape (Intercallosal subdivision perpendicular to intercallosal line), vertical \
-            (orthogonal to the most anterior and posterior points in the AC/PC standardized CC contour), \
-            angular (subdivision based on equally spaced angles, as proposed by Hampel and colleagues), \
-            eigenvector (primary direction, same as FreeSurfers mri_cc)",
+        default=_do_not_print("shape"),
+        help="Method for contour subdivision. Options:
" + "- shape (default): Intercallosal subdivision perpendicular to intercallosal line,
" + "- vertical: orthogonal to the most anterior and posterior points in the AC/PC standardized CC contour, " + "
" + "- angular: subdivision based on equally spaced angles, as proposed by Hampel and colleagues,
" + "- eigenvector: primary direction, same as FreeSurfers mri_cc.", choices=["shape", "vertical", "angular", "eigenvector"], ) parser.add_argument( "--contour_smoothing", type=float, default=5, - help="Gaussian sigma for smoothing during contour detection. Higher values mean a smoother" - " CC outline, at the cost of precision. (default: 5)", + help="Gaussian sigma for smoothing during contour detection. Higher values mean a smoother CC outline, at the " + "cost of precision.", ) def _slice_selection(a: str) -> SliceSelection: if a.lower() in ("middle", "all"): @@ -112,140 +201,135 @@ def _slice_selection(a: str) -> SliceSelection: parser.add_argument( "--slice_selection", type=_slice_selection, - default="all", - help="Which slices to process. Options: 'middle', 'all', or a specific slice number. \ - (default: 'all')", - ) - parser.add_argument( - "-v", - "--verbose", - action="count", - default=0, - help="Enable verbose (pass twice for debug-output).", + default=_do_not_print("all"), + help="Which slices to process. Options: 'middle', 'all' (default), or a specific slice number.", ) ######## OUTPUT PATHS ######### # 4. Options for advanced, technical parameters advanced = parser.add_argument_group( title="Advanced options", - description="Custom output paths, useful if no standard case directory is used.", + description="Custom output paths, useful if no standard case directory is used. Relative paths are always " + "relative to the subject_dir defined via --sd and --sid!", ) - advanced.add_argument("--qc_output_dir", - type=Path, + add_arguments(advanced, ["threads"]) + advanced.add_argument( + "--qc_output_dir", + type=path_or_none, required=False, default=None, - help="Directory for quality control output (default: subject_dir/qc_snapshots)") + help="Enables quality control output (paths starting with {qc_output_dir} by default) and sets {qc_output_dir} " + "(the FastSurfer standard is 'qc_snapshots' to save these files in subject_dir/qc_snapshots).", + ) advanced.add_argument( - "--upright_volume_path", - type=Path, - help="Path for upright volume output (default: No output)", + "--upright_volume", + type=path_or_none, + help="Path for upright volume output.", default=None, ) advanced.add_argument( - "--seg", - dest="segmentation_path", - type=Path, - help=f"Path for segmentation output (default: subject_dir/{STANDARD_OUTPUT_PATHS['segmentation']})", - default=None, + "--segmentation", "--seg", + type=path_or_none, + help="Path for corpus callosum and fornix segmentation 3D image.", + default=Path(STANDARD_OUTPUT_PATHS["segmentation"]), ) advanced.add_argument( - "--postproc_results_path", - type=Path, - help=f"Path for postprocessing results. Contains metrics describing CC shape and volume for each slice \ - (default: subject_dir/{STANDARD_OUTPUT_PATHS['postproc_results']})", - default=None, + "--cc_measures", + type=path_or_none, + help="Path for surface-based corpus callosum measures describing shape and volume for each image slice.", + default=Path(STANDARD_OUTPUT_PATHS["cc_measures"]), ) advanced.add_argument( - "--cc_markers_path", - type=Path, - help=f"Path for CC markers output. 
Contains metrics describing CC shape and volume \
-            (default: subject_dir/{STANDARD_OUTPUT_PATHS['cc_markers']})",
-        default=None,
+        "--cc_mid_measures",
+        type=path_or_none,
+        help="Path for surface-based corpus callosum measures of the midslice describing CC shape and volume.",
+        default=STANDARD_OUTPUT_PATHS["cc_markers"],
     )
     advanced.add_argument(
-        "--upright_lta_path",
-        type=Path,
-        help=f"Path for upright LTA transform. This makes sure the midplane is at 128 in LR direction, but no nodding \
-            correction is applied (default: subject_dir/{STANDARD_OUTPUT_PATHS['upright_lta']})",
-        default=None,
+        "--upright_lta",
+        type=path_or_none,
+        help="Path for upright LTA transform. This makes sure the midplane is at 128 in LR direction, but no nodding "
+             "correction is applied.",
+        default=STANDARD_OUTPUT_PATHS["upright_lta"],
     )
     advanced.add_argument(
-        "--orient_volume_lta_path",
-        type=Path,
-        help=f"Path for orientation volume LTA transform. This makes sure the midplane is at 128 in LR direction, \
-            and the AC & PC are on the coordinate line, standardizing the head orientation. \
-            (default: subject_dir/{STANDARD_OUTPUT_PATHS['orient_volume_lta']})",
-        default=None,
+        "--orient_volume_lta",
+        type=path_or_none,
+        help="Path for orientation volume LTA transform. This makes sure the midplane is at 128 in LR direction, and "
+             "the anterior and posterior commissures are on the coordinate line, standardizing the head orientation.",
+        default=STANDARD_OUTPUT_PATHS["orient_volume_lta"],
     )
     advanced.add_argument(
-        "--orig_space_segmentation_path",
-        type=Path,
-        help="Path for segmentation in the input MRI space "
-             f"(default: subject_dir/{STANDARD_OUTPUT_PATHS['orig_space_segmentation']})",
-        default=None,
+        "--segmentation_in_orig",
+        type=path_or_none,
+        help="Path for corpus callosum and fornix segmentation in the input MRI space.",
+        default=STANDARD_OUTPUT_PATHS["segmentation_in_orig"],
     )
     advanced.add_argument(
-        "--qc_image_path",
-        type=Path,
-        help=f"Path for QC visualization image (default: subject_dir/{STANDARD_OUTPUT_PATHS['qc_image']})",
-        default=None,
+        "--qc_image",
+        type=ReplaceQCOutputDir,
+        help="Path for QC visualization image (if it starts with {qc_output_dir}, that is replaced by --qc_output_dir).",
+        default=STANDARD_OUTPUT_PATHS["qc_image"],
     )
     advanced.add_argument(
         "--save_template_dir",
-        type=Path,
-        help="Directory path where to save contours.txt and thickness_values.txt files. \
-            These files can be used to visualize the CC shape and volume in 3D.",
+        type=path_or_none,
+        help="Directory path where to save contours.txt and thickness_values.txt files. These files can be used to "
+             "visualize the CC shape and volume in 3D.",
         default=None,
     )
     advanced.add_argument(
-        "--thickness_image_path",
-        type=Path,
-        help=f"Path for thickness image (default: subject_dir/{STANDARD_OUTPUT_PATHS['thickness_image']})",
-        default=None,
+        "--thickness_image",
+        type=ReplaceQCOutputDir,
+        help="Path for thickness image (if it starts with {qc_output_dir}, that is replaced by --qc_output_dir).",
+        default=STANDARD_OUTPUT_PATHS["thickness_image"],
     )
     advanced.add_argument(
-        "--surf_file_path",
-        type=Path,
-        help=f"Path for surf file (default: subject_dir/{STANDARD_OUTPUT_PATHS['surf_file']})",
-        default=None,
+        "--surf",
+        dest="cc_surf",
+        type=path_or_none,
+        help="Path for surf file.",
+        default=STANDARD_OUTPUT_PATHS["cc_surf"],
     )
     advanced.add_argument(
-        "--overlay_file_path",
-        type=Path,
-        help=f"Path for overlay file (default: subject_dir/{STANDARD_OUTPUT_PATHS['overlay_file']})",
-        default=None,
+        "--thickness_overlay",
+        type=path_or_none,
+        help="Path for corpus callosum thickness overlay file.",
+        default=STANDARD_OUTPUT_PATHS["cc_thickness_overlay"],
     )
     advanced.add_argument(
-        "--cc_html_path",
-        type=Path,
-        help=f"Path to CC 3D visualization for CC HTML file (default: subject_dir/{STANDARD_OUTPUT_PATHS['cc_html']})",
-        default=None,
+        "--cc_interactive_html", "--cc_html",
+        dest="cc_html",
+        type=ReplaceQCOutputDir,
+        help="Path to the corpus callosum interactive 3D visualization HTML file (if it starts with {qc_output_dir}, "
+             "that is replaced by --qc_output_dir).",
+        default=STANDARD_OUTPUT_PATHS["cc_html"],
     )
     advanced.add_argument(
-        "--vtk_file_path",
-        type=Path,
-        help=f"Path for vtk file, showing the CC 3D mesh (default: subject_dir/{STANDARD_OUTPUT_PATHS['vtk_file']})",
+        "--cc_surf_vtk",
+        type=path_or_none,
+        help=f"Path for vtk file, showing the CC 3D mesh. Example: {STANDARD_OUTPUT_PATHS['cc_surf_vtk']}.",
         default=None,
     )
     advanced.add_argument(
-        "--softlabels_cc_path",
-        type=Path,
-        help=f"Path for cc softlabels. Contains the probability of each voxel being part of the CC \
-            (default: subject_dir/{STANDARD_OUTPUT_PATHS['softlabels_cc']})",
+        "--softlabels_cc",
+        type=path_or_none,
+        help=f"Path for corpus callosum softlabels, which contains the soft labels of each voxel. "
+             f"Example: {STANDARD_OUTPUT_PATHS['softlabels_cc']}.",
         default=None,
     )
     advanced.add_argument(
-        "--softlabels_fn_path",
-        type=Path,
-        help=f"Path for fornix softlabels. Contains the probability of each voxel being part of the Fornix \
-            (default: subject_dir/{STANDARD_OUTPUT_PATHS['softlabels_fn']})",
+        "--softlabels_fn",
+        type=path_or_none,
+        help=f"Path for fornix softlabels, which contains the soft labels of each voxel. "
+             f"Example: {STANDARD_OUTPUT_PATHS['softlabels_fn']}.",
        default=None,
     )
     advanced.add_argument(
-        "--softlabels_background_path",
-        type=Path,
-        help=f"Path for background softlabels. Contains the probability of each voxel being part of the background \
-            (default: subject_dir/{STANDARD_OUTPUT_PATHS['softlabels_background']})",
+        "--softlabels_background",
+        type=path_or_none,
+        help=f"Path for background softlabels, which contains the probability of each voxel. 
" + f"Example: {STANDARD_OUTPUT_PATHS['softlabels_background']}.", default=None, ) ############ END OF OUTPUT PATHS ############ @@ -258,7 +342,7 @@ def options_parse() -> argparse.Namespace: args = parser.parse_args() # Reconstruct subject_dir from sd and sid (but sd might be stored as out_dir by parser_defaults) - sd_value = getattr(args, 'sd', getattr(args, 'out_dir', None)) + sd_value = getattr(args, 'out_dir', None) if sd_value and hasattr(args, 'sid') and args.sid: args.subject_dir = Path(sd_value) / args.sid else: @@ -286,33 +370,25 @@ def options_parse() -> argparse.Namespace: # If subject_dir is provided, set default paths for missing arguments if args.subject_dir: - subject_dir_path = args.subject_dir - # Create standard FreeSurfer subdirectories - (subject_dir_path / "mri").mkdir(parents=True, exist_ok=True) - (subject_dir_path / "stats").mkdir(parents=True, exist_ok=True) - (subject_dir_path / "transforms").mkdir(parents=True, exist_ok=True) - if not args.conf_name: - args.conf_name = str(subject_dir_path / STANDARD_INPUT_PATHS["conf_name"]) + args.conf_name = args.subject_dir / STANDARD_INPUT_PATHS["conf_name"] if not args.aseg_name: - args.aseg_name = str(subject_dir_path / STANDARD_INPUT_PATHS["aseg_name"]) - - # Set default output paths if not provided - for key, value in STANDARD_OUTPUT_PATHS.items(): - if not getattr(args, f"{key}_path") and value is not None: - setattr(args, f"{key}_path", subject_dir_path / value) + args.aseg_name = args.subject_dir / STANDARD_INPUT_PATHS["aseg_name"] - # Set output_dir to subject_dir - args.output_dir = str(subject_dir_path) + all_paths = ("segmentation", "segmentation_in_orig", "cc_measures", "upright_lta", "orient_volume_lta", "cc_surf", + "softlabels_cc", "softlabels_fn", "softlabels_background", "cc_mid_measures", "cc_thickness_overlay", + "qc_image", "thickness_image", "cc_html") # Create parent directories for all output paths - for path_name in STANDARD_OUTPUT_PATHS.keys(): - path = getattr(args, f"{path_name}_path") - if path is not None: - Path(path).parent.mkdir(parents=False, exist_ok=True) - + for path_name in all_paths: + path: ReplaceQCOutputDir | Path | None = getattr(args, path_name, None) + if isinstance(path, ReplaceQCOutputDir): + path = path.replace_qc_dir(getattr(args, "qc_output_dir", None)) + if isinstance(path, Path) and not args.subject_dir and not path.is_absolute(): + parser.error(f"Must specify --sd and --sid if any path is relative but {path} for {path_name} is relative.") + setattr(args, path_name, path) return args @@ -352,7 +428,7 @@ def centroid_registration(aseg_nib: nib.analyze.SpatialImage) -> tuple[ # Load pre-computed fsaverage centroids and data from static files centroids_dst = load_fsaverage_centroids(FSAVERAGE_CENTROIDS_PATH) - fsaverage_affine, fsaverage_header, vox2ras_tkr = load_fsaverage_data(FSAVERAGE_DATA_PATH) + fsaverage_data_future = thread_executor().submit(load_fsaverage_data, FSAVERAGE_DATA_PATH) centroids_mov = get_centroids_from_nib(aseg_nib, label_ids=list(centroids_dst.keys())) @@ -367,9 +443,9 @@ def centroid_registration(aseg_nib: nib.analyze.SpatialImage) -> tuple[ # make affine that increases resolution to orig resolution resolution_trans: npt.NDArray[float] = np.diagflat(list(aseg_nib.header.get_zooms()[:3]) + [1]).astype(float) - orig_fsaverage_vox2vox: npt.NDArray[float] = ( - np.linalg.inv(resolution_trans @ fsaverage_affine) @ orig_fsaverage_ras2ras @ aseg_nib.affine - ) + fsaverage_affine, fsaverage_header, vox2ras_tkr = fsaverage_data_future.result() + 
_highres_fsaverage: npt.NDArray[float] = np.linalg.inv(resolution_trans @ fsaverage_affine) + orig_fsaverage_vox2vox: npt.NDArray[float] = _highres_fsaverage @ orig_fsaverage_ras2ras @ aseg_nib.affine fsaverage_hires_affine: npt.NDArray[float] = resolution_trans @ fsaverage_affine logger.info("Centroid registration successful!") return orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, fsaverage_hires_affine, fsaverage_header, vox2ras_tkr @@ -488,7 +564,6 @@ def main( aseg_name: str | Path, subject_dir: str | Path, slice_selection: SliceSelection = "middle", - #TODO: qc_output_dir is currently unused ?! qc_output_dir: str | Path | None = None, num_thickness_points: int = 100, subdivisions: list[float] | None = None, @@ -496,22 +571,22 @@ def main( contour_smoothing: float = 5, save_template_dir: str | Path | None = None, device: str | torch.device = "auto", - upright_volume_path: str | Path | None = None, - segmentation_path: str | Path | None = None, - postproc_results_path: str | Path | None = None, - cc_markers_path: str | Path | None = None, - upright_lta_path: str | Path | None = None, - orient_volume_lta_path: str | Path | None = None, - surf_file_path: str | Path | None = None, - overlay_file_path: str | Path | None = None, - cc_html_path: str | Path | None = None, - vtk_file_path: str | Path | None = None, - orig_space_segmentation_path: str | Path | None = None, - qc_image_path: str | Path | None = None, - thickness_image_path: str | Path | None = None, - softlabels_cc_path: str | Path | None = None, - softlabels_fn_path: str | Path | None = None, - softlabels_background_path: str | Path | None = None, + upright_volume: str | Path | None = None, + segmentation: str | Path | None = None, + cc_measures: str | Path | None = None, + cc_mid_measures: str | Path | None = None, + upright_lta: str | Path | None = None, + orient_volume_lta: str | Path | None = None, + cc_surf: str | Path | None = None, + cc_thickness_overlay: str | Path | None = None, + cc_html: str | Path | None = None, + cc_surf_vtk: str | Path | None = None, + segmentation_in_orig: str | Path | None = None, + qc_image: str | Path | None = None, + thickness_image: str | Path | None = None, + softlabels_cc: str | Path | None = None, + softlabels_fn: str | Path | None = None, + softlabels_background: str | Path | None = None, ) -> None: """Main pipeline function for corpus callosum analysis. @@ -529,7 +604,7 @@ def main( slice_selection : "middle", "all" or int, default="middle" Which slices to process. qc_output_dir : str or Path, optional - Directory for quality control outputs, None deactivates qc snapshots. + Directory for quality control outputs, activates qc_image, thickness_image, cc_html. num_thickness_points : int, default=100 Number of points for thickness estimation. subdivisions : list[float], optional @@ -543,37 +618,37 @@ def main( the CC shape and volume in 3D. Files are only saved, if a valid directory path is passed. device : str, default="auto" Device to run inference on ('auto', 'cpu', 'cuda', or 'cuda:X'). - upright_volume_path : str or Path, optional + upright_volume : str or Path, optional Path to save upright volume. - segmentation_path : str or Path, optional + segmentation : str or Path, optional Path to save segmentation. - postproc_results_path : str or Path, optional + cc_measures : str or Path, optional Path to save post-processing results. - cc_markers_path : str or Path, optional + cc_mid_measures : str or Path, optional Path to save CC markers. 
- upright_lta_path : str or Path, optional + upright_lta : str or Path, optional Path to save upright LTA transform. - orient_volume_lta_path : str or Path, optional + orient_volume_lta : str or Path, optional Path to save orientation transform. - surf_file_path : str or Path, optional + cc_surf : str or Path, optional Path to save surface file. - overlay_file_path : str or Path, optional + cc_thickness_overlay : str or Path, optional Path to save overlay file. - cc_html_path : str or Path, optional + cc_html : str or Path, optional Path to save HTML visualization. - vtk_file_path : str or Path, optional + cc_surf_vtk : str or Path, optional Path to save VTK file. - orig_space_segmentation_path : str or Path, optional + segmentation_in_orig : str or Path, optional Path to save segmentation in original space. - qc_image_path : str or Path, optional + qc_image : str or Path, optional Path to save QC images. - thickness_image_path : str or Path, optional + thickness_image : str or Path, optional Path to save thickness visualization. - softlabels_cc_path : str or Path, optional + softlabels_cc : str or Path, optional Path to save CC soft labels. - softlabels_fn_path : str or Path, optional + softlabels_fn : str or Path, optional Path to save fornix soft labels. - softlabels_background_path : str or Path, optional + softlabels_background : str or Path, optional Path to save background soft labels. Notes @@ -594,6 +669,8 @@ def main( 5. Performs enhanced post-processing analysis. 6. Saves results and visualizations. """ + start = perf_counter_ns() + import sys if subdivisions is None: @@ -605,7 +682,7 @@ def main( logger.info(f"Input MRI: {conf_name}") logger.info(f"Input segmentation: {aseg_name}") logger.info(f"Output directory: {subject_dir}") - + # Convert all paths to Path objects sd = SubjectDirectory( subject_dir.parent, @@ -613,22 +690,22 @@ def main( conf_name=conf_name, aseg_name=aseg_name, save_template_dir=save_template_dir, - upright_volume=upright_volume_path, - cc_segmentation=segmentation_path, - cc_postproc_results=postproc_results_path, - cc_markers=cc_markers_path, - upright_lta=upright_lta_path, - cc_orient_volume_lta=orient_volume_lta_path, - cc_surf=surf_file_path, - cc_overlay=overlay_file_path, - cc_html=cc_html_path, - cc_mesh=vtk_file_path, - cc_orig_segfile=orig_space_segmentation_path, - cc_qc_images=qc_image_path, - cc_thickness_image=thickness_image_path, - cc_softlabels_cc=softlabels_cc_path, - cc_softlabels_fn=softlabels_fn_path, - cc_softlabels_background=softlabels_background_path, + upright_volume=upright_volume, + cc_segmentation=segmentation, + cc_measures=cc_measures, + cc_mid_measures=cc_mid_measures, + upright_lta=upright_lta, + cc_orient_volume_lta=orient_volume_lta, + cc_surf=cc_surf, + cc_thickness_overlay=cc_thickness_overlay, + cc_html=cc_html, + cc_mesh=cc_surf_vtk, + cc_orig_segfile=segmentation_in_orig, + cc_qc_image=qc_image, + cc_thickness_image=thickness_image, + cc_softlabels_cc=softlabels_cc, + cc_softlabels_fn=softlabels_fn, + cc_softlabels_background=softlabels_background, ) # Validate subdivision fractions @@ -639,7 +716,7 @@ def main( #### setup variables io_futures = [] - orig = nib.load(sd.conf_name) + orig = cast(nib.analyze.SpatialImage, nib.load(sd.conf_name)) # 5 mm around the midplane (making sure to get rl by as_closest_canonical) slices_to_analyze = int(np.ceil(5 / nib.as_closest_canonical(orig).header.get_zooms()[0])) // 2 * 2 + 1 @@ -658,7 +735,7 @@ def main( # load models device = find_device(device) logger.info(f"Using device: 
{device}") - + logger.info("Loading models") model_localization = localization_inference.load_model(device=device) model_segmentation = segmentation_inference.load_model(device=device) @@ -676,7 +753,7 @@ def main( # start saving upright volume if sd.has_attribute("upright_volume"): io_futures.append( - executor().submit( + thread_executor().submit( apply_transform_to_volume, orig, orig2fsavg_vox2vox, @@ -705,7 +782,7 @@ def main( for i, (attr, name) in enumerate((("background",) * 2, ("cc", "Corpus Callosum"), ("fn", "Fornix"))): if sd.has_attribute(f"cc_softlabels_{attr}"): logger.info(f"Saving {name} softlabels to {sd.filename_by_attribute(f'cc_softlabels_{attr}')}") - io_futures.append(executor().submit( + io_futures.append(thread_executor().submit( nib.save, nib.MGHImage(cc_fn_softlabels[..., i], seg_affine, orig.header), sd.filename_by_attribute(f"cc_softlabels_{attr}"), @@ -713,6 +790,7 @@ def main( # Create a temporary segmentation image with proper affine for enhanced postprocessing # Process slices based on selection mode + logger.info(f"Processing slices with selection mode: {slice_selection}") slice_results, slice_io_futures = recon_cc_surf_measures_multi( segmentation=cc_fn_seg_labels, @@ -742,27 +820,26 @@ def main( middle_slice_result = slice_results[len(slice_results) // 2] if len(middle_slice_result['split_contours']) <= 5: - subdivision_mask = make_subdivision_mask( + cc_subseg_midslice = make_subdivision_mask( cc_fn_seg_labels.shape[1:], middle_slice_result['split_contours'], orig.header.get_zooms(), ) else: logger.warning("Too many subsegments for lookup table, skipping sub-divion of output segmentation.") - subdivision_mask = None + cc_subseg_midslice = None - # map soft labels to original space (in parallel because this takes a while) - io_futures.append(executor().submit( + # map soft labels to original space (in parallel because this takes a while, and we only do it to save the labels) + io_futures.append(thread_executor().submit( map_softlabels_to_orig, - outputs_soft=cc_fn_softlabels, + cc_fn_softlabels=cc_fn_softlabels, orig_fsaverage_vox2vox=orig2fsavg_vox2vox, orig=orig, - slices_to_analyze=slices_to_analyze, - orig_space_segmentation_path=orig_space_segmentation_path, + orig_space_segmentation_path=segmentation_in_orig, fsaverage_middle=FSAVERAGE_MIDDLE, - subdivision_mask=subdivision_mask, + cc_subseg_midslice=cc_subseg_midslice, )) - io_futures.append(executor().submit( + io_futures.append(thread_executor().submit( nib.save, nib.MGHImage(cc_fn_seg_labels, seg_affine, orig.header), sd.filename_by_attribute("cc_segmentation"), @@ -795,26 +872,26 @@ def main( additional_metrics = {} if len(outer_contours) > 1: cc_volume_voxel = segmentation_postprocessing.get_cc_volume_voxel( - desired_width_mm=5, + desired_width_mm=5, cc_mask=cc_fn_seg_labels == CC_LABEL, voxel_size=orig.header.get_zooms() ) cc_volume_contour = segmentation_postprocessing.get_cc_volume_contour( - cc_contours=outer_contours, + cc_contours=outer_contours, voxel_size=orig.header.get_zooms() ) logger.info(f"CC volume voxel: {cc_volume_voxel}") logger.info(f"CC volume contour: {cc_volume_contour}") - + additional_metrics["cc_5mm_volume"] = cc_volume_voxel additional_metrics["cc_5mm_volume_pv_corrected"] = cc_volume_contour - + # get ac and pc in all spaces ac_coords_3d = np.hstack((FSAVERAGE_MIDDLE, ac_coords)) pc_coords_3d = np.hstack((FSAVERAGE_MIDDLE, pc_coords)) - standardized_to_orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig = ( + 
- + # get ac and pc in all spaces ac_coords_3d = np.hstack((FSAVERAGE_MIDDLE, ac_coords)) pc_coords_3d = np.hstack((FSAVERAGE_MIDDLE, pc_coords)) - standardized_to_orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig = ( + standardized2orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig = ( calc_mapping_to_standard_space(orig, ac_coords_3d, pc_coords_3d, orig2fsavg_vox2vox) ) @@ -836,45 +913,57 @@ def main( # Convert numpy arrays to lists for JSON serialization output_metrics_middle_slice = convert_numpy_to_json_serializable(output_metrics_middle_slice | additional_metrics) - logger.info(f"Saving CC markers to {sd.filename_by_attribute('cc_markers')}") - with open(sd.filename_by_attribute("cc_markers"), "w") as f: - json.dump(output_metrics_middle_slice, f, indent=4) - - per_slice_output_dict = convert_numpy_to_json_serializable(per_slice_output_dict | additional_metrics) - - # Save slice-wise postprocessing results to JSON - with open(sd.filename_by_attribute("cc_postproc_results"), "w") as f: - json.dump(per_slice_output_dict, f, indent=4) + if sd.has_attribute("cc_mid_measures"): + logger.info(f"Saving CC markers to {sd.filename_by_attribute('cc_mid_measures')}") + sd.filename_by_attribute("cc_mid_measures").parent.mkdir(exist_ok=True, parents=True) + with open(sd.filename_by_attribute("cc_mid_measures"), "w") as f: + json.dump(output_metrics_middle_slice, f, indent=4) - logger.info(f"Multiple slice post-processing results saved to {sd.filename_by_attribute('cc_postproc_results')}") + if sd.has_attribute("cc_measures"): + per_slice_output_dict = convert_numpy_to_json_serializable(per_slice_output_dict | additional_metrics) + sd.filename_by_attribute("cc_measures").parent.mkdir(exist_ok=True, parents=True) + # Save slice-wise postprocessing results to JSON + with open(sd.filename_by_attribute("cc_measures"), "w") as f: + json.dump(per_slice_output_dict, f, indent=4) + logger.info(f"Multiple slice post-processing results saved to {sd.filename_by_attribute('cc_measures')}") # save lta to fsaverage space - logger.info(f"Saving LTA to fsaverage space: {sd.filename_by_attribute('upright_lta')}") - lta.writeLTA( - sd.filename_by_attribute("upright_lta"), - orig2fsavg_ras2ras, - sd.filename_by_attribute("aseg_name"), - aseg_nib.header, - "fsaverage", - fsavg_header, - ) - - # save lta to standardized space (fsaverage + nodding + ac to center) - orig2standardized_ras2ras = orig.affine @ np.linalg.inv(standardized_to_orig_vox2vox) @ np.linalg.inv(orig.affine) - logger.info(f"Saving LTA to standardized space: {sd.filename_by_attribute('cc_orient_volume_lta')}") - lta.writeLTA( - sd.filename_by_attribute("cc_orient_volume_lta"), - orig2standardized_ras2ras, - sd.conf_name, - orig.header, - sd.conf_name, - orig.header, - ) - - for e in filter(lambda x: x and isinstance(x, Exception), (fut.exception() for fut in io_futures)): - logger.exception(e) - logger.info("CorpusCallosum analysis pipeline completed successfully") + if sd.has_attribute("upright_lta"): + sd.filename_by_attribute("upright_lta").parent.mkdir(exist_ok=True, parents=True) + logger.info(f"Saving LTA to fsaverage space: {sd.filename_by_attribute('upright_lta')}") + io_futures.append(thread_executor().submit(write_lta, + sd.filename_by_attribute("upright_lta"), + orig2fsavg_ras2ras, + sd.filename_by_attribute("aseg_name"), + aseg_nib.header, + "fsaverage", + fsavg_header, + )) + + if sd.has_attribute("cc_orient_volume_lta"): + sd.filename_by_attribute("cc_orient_volume_lta").parent.mkdir(exist_ok=True, parents=True) + # save lta to standardized space (fsaverage + nodding + ac to center) + orig2standardized_ras2ras = orig.affine @ np.linalg.inv(standardized2orig_vox2vox) @ np.linalg.inv(orig.affine)
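The ras2ras construction above follows the usual pattern `A @ inv(vox2vox) @ inv(A)` for two volumes that share the conformed vox2ras affine `A`. A self-contained sanity check of that identity (illustrative toy values only):

```python
# Illustrative check of the ras2ras pattern used above: if
# orig_vox = standardized2orig_vox2vox @ std_vox and both spaces share the
# conformed affine A (ras = A @ vox), then
# orig2standardized_ras2ras = A @ inv(standardized2orig_vox2vox) @ inv(A).
import numpy as np

A = np.array([[-1., 0., 0., 128.],   # toy LIA-like vox2ras affine
              [0., 0., 1., -128.],
              [0., -1., 0., 128.],
              [0., 0., 0., 1.]])
standardized2orig_vox2vox = np.eye(4)
standardized2orig_vox2vox[:3, 3] = [2., -3., 1.]  # toy translation-only vox2vox

orig2standardized_ras2ras = A @ np.linalg.inv(standardized2orig_vox2vox) @ np.linalg.inv(A)

orig_vox = np.array([100., 120., 110., 1.])
std_vox = np.linalg.inv(standardized2orig_vox2vox) @ orig_vox
# mapping the RAS point directly agrees with going through voxel space
assert np.allclose(orig2standardized_ras2ras @ (A @ orig_vox), A @ std_vox)
```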
logger.info(f"Saving LTA to standardized space: {sd.filename_by_attribute('cc_orient_volume_lta')}") + io_futures.append(thread_executor().submit(write_lta, + sd.filename_by_attribute("cc_orient_volume_lta"), + orig2standardized_ras2ras, + sd.conf_name, + orig.header, + sd.conf_name, + orig.header, + )) + + # this waits for all io to finish + for fut in io_futures: + e = fut.exception() + if e and isinstance(e, Exception): + logger.exception(e) + shutdown_executors() + + duration = (perf_counter_ns() - start) / 1e9 + logger.info(f"CorpusCallosum analysis pipeline completed successfully in {duration:.2f} seconds.") if __name__ == "__main__": @@ -890,25 +979,25 @@ def main( slice_selection=options.slice_selection, qc_output_dir=options.qc_output_dir, num_thickness_points=options.num_thickness_points, - subdivisions=options.subdivisions, - subdivision_method=options.subdivision_method, + subdivisions=list(options.subdivisions), # default value is type _fmt_list (does not pickle) + subdivision_method=str(options.subdivision_method), # default value is type do not print (does not pickle) contour_smoothing=options.contour_smoothing, save_template_dir=options.save_template_dir, device=options.device, - upright_volume_path=options.upright_volume_path, - segmentation_path=options.segmentation_path, - postproc_results_path=options.postproc_results_path, - cc_markers_path=options.cc_markers_path, - upright_lta_path=options.upright_lta_path, - orient_volume_lta_path=options.orient_volume_lta_path, - surf_file_path=options.surf_file_path, - overlay_file_path=options.overlay_file_path, - cc_html_path=options.cc_html_path, - vtk_file_path=options.vtk_file_path, - orig_space_segmentation_path=options.orig_space_segmentation_path, - qc_image_path=options.qc_image_path, - thickness_image_path=options.thickness_image_path, - softlabels_cc_path=options.softlabels_cc_path, - softlabels_fn_path=options.softlabels_fn_path, - softlabels_background_path=options.softlabels_background_path, + upright_volume=options.upright_volume, + segmentation=options.segmentation, + cc_measures=options.cc_measures, + cc_mid_measures=options.cc_mid_measures, + upright_lta=options.upright_lta, + orient_volume_lta=options.orient_volume_lta, + cc_surf=options.cc_surf, + cc_thickness_overlay=options.thickness_overlay, + cc_html=options.cc_html, + cc_surf_vtk=options.cc_surf_vtk, + segmentation_in_orig=options.segmentation_in_orig, + qc_image=options.qc_image, + thickness_image=options.thickness_image, + softlabels_cc=options.softlabels_cc, + softlabels_fn=options.softlabels_fn, + softlabels_background=options.softlabels_background, ) diff --git a/CorpusCallosum/paint_cc_into_pred.py b/CorpusCallosum/paint_cc_into_pred.py index 7815f9a1..2a577d1b 100644 --- a/CorpusCallosum/paint_cc_into_pred.py +++ b/CorpusCallosum/paint_cc_into_pred.py @@ -17,13 +17,22 @@ import argparse import sys +from pathlib import Path +from typing import TypeVar, cast import nibabel as nib import numpy as np from numpy import typing as npt from scipy import ndimage +from CorpusCallosum.data.constants import FORNIX_LABEL, SUBSEGMENT_LABELS from FastSurferCNN.data_loader.conform import is_conform +from FastSurferCNN.reduce_to_aseg import reduce_to_aseg_and_save +from FastSurferCNN.utils.arg_types import path_or_none +from FastSurferCNN.utils.brainvolstats import mask_in_array +from FastSurferCNN.utils.parallel import thread_executor + +_T = TypeVar("_T", bound=np.number) HELPTEXT = """ Script to add corpus callosum segmentation (CC, FreeSurfer IDs 251-255) to @@ 
HELPTEXT = """ Script to add corpus callosum segmentation (CC, FreeSurfer IDs 251-255) to @@ -49,32 +58,52 @@ def argument_parse(): """Create a command line interface and return command line options. """ + parser = make_parser() + + args = parser.parse_args() + + if args.input_cc is None or args.input_pred is None or args.output is None: + sys.exit("ERROR: Please specify input and output segmentations") + + return args + + +def make_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(usage=HELPTEXT) parser.add_argument( "--input_cc", "-in_cc", dest="input_cc", + type=Path, + required=True, help="path to input segmentation with Corpus Callosum (IDs 251-255 in FreeSurfer space)", ) parser.add_argument( "--input_pred", "-in_pred", dest="input_pred", + type=Path, + required=True, help="path to input segmentation Corpus Callosum should be added to.", ) parser.add_argument( "--output", "-out", dest="output", + type=Path, + required=True, help="path to output (input segmentation + added CC)", ) - - args = parser.parse_args() - - if args.input_cc is None or args.input_pred is None or args.output is None: - sys.exit("ERROR: Please specify input and output segmentations") - - return args + parser.add_argument( + "--reduce_to_aseg", + "-aseg", + dest="aseg", + type=path_or_none, + required=False, + help="optionally also reduce the resulting segmentation to aseg and save separately.", + default=None, + ) + return parser def paint_in_cc(pred: npt.NDArray[np.int_], @@ -98,7 +127,7 @@ def paint_in_cc(pred: npt.NDArray[np.int_], This function modifies the original array and does not create a copy. The CC labels (251-255) from aseg_cc are copied into pred. """ - cc_mask = (aseg_cc >= 251) & (aseg_cc <= 255) + cc_mask = mask_in_array(aseg_cc, SUBSEGMENT_LABELS) pred[cc_mask] = aseg_cc[cc_mask] return pred @@ -120,7 +149,7 @@ def correct_wm_ventricles( corrected_pred = aseg_cc.copy() # Get CC mask (labels 251-255) - cc_mask = (aseg_cc >= 251) & (aseg_cc <= 255) + cc_mask = mask_in_array(aseg_cc, SUBSEGMENT_LABELS) # Get left and right ventricle masks all_ventricle_mask = (aseg_cc == 4) | (aseg_cc == 43) @@ -135,10 +164,9 @@ # Process each slice independently for x in range(corrected_pred.shape[0]): - cc_slice = cc_mask[x, :, :] - #vent_slice = ventricle_mask[x, :, :] - all_wm_slice = all_wm_mask[x, :, :] - + cc_slice = cc_mask[x] + #vent_slice = ventricle_mask[x] + all_wm_slice = all_wm_mask[x] if all_wm_slice.any() and cc_slice.any(): @@ -152,22 +180,19 @@ component_mask = labeled_wm == label # Check if this component is adjacent to (touches) the CC if np.any(component_mask & cc_dilated): - corrected_pred[x, :, :][component_mask] = 0 # Set to background + corrected_pred[x][component_mask] = 0 # Set to background - if fornix_mask[x, :, :].any(): - fornix_slice = fornix_mask[x, :, :] + if fornix_mask[x].any(): + fornix_slice = fornix_mask[x] # count WM labels overlapping with fornix left_wm_overlap = np.sum(fornix_slice & (aseg_cc == 2)) right_wm_overlap = np.sum(fornix_slice & (aseg_cc == 41)) - if left_wm_overlap > right_wm_overlap: - corrected_pred[x, :, :][fornix_slice] = 2 # Left WM - else: - corrected_pred[x, :, :][fornix_slice] = 41 # Right WM - + corrected_pred[x][fornix_slice] = 2 + (left_wm_overlap <= right_wm_overlap) * 39 # Left WM / Right WM - vent_slice = all_ventricle_mask[x, :, :] + vent_slice = all_ventricle_mask[x] + potential_fill = np.asarray([False]) if cc_slice.any() and vent_slice.any(): # Create binary masks for this slice cc_binary = cc_slice.astype(bool)
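The gap-filling logic in the next hunk hinges on one test: a background gap is only filled if a one-iteration binary dilation of the gap touches both the CC and the ventricle masks. A self-contained 1-D illustration of that test (toy data, not pipeline code):

```python
# 1-D illustration of the adjacency test used in the gap filling below:
# a candidate gap is filled only if its dilation touches both the CC line
# and the ventricle line.
import numpy as np
from scipy import ndimage

cc_line = np.array([1, 1, 0, 0, 0, 0, 0], dtype=bool)
vent_line = np.array([0, 0, 0, 0, 1, 1, 0], dtype=bool)
gap_line = ~(cc_line | vent_line)  # candidate background gaps

labeled_gaps, num_gaps = ndimage.label(gap_line)
for gap_label in range(1, num_gaps + 1):
    gap_mask = labeled_gaps == gap_label
    dilated = ndimage.binary_dilation(gap_mask, iterations=1)
    touches_both = np.any(cc_line & dilated) and np.any(vent_line & dilated)
    print(gap_label, int(gap_mask.sum()), touches_both)
# gap 1 (indices 2-3) touches both and would be filled;
# gap 2 (index 6) only touches the ventricle and is left alone
```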
@@ -181,57 +206,57 @@ def correct_wm_ventricles( # Find voxels that are adjacent to both CC and ventricle but not part of either potential_fill = (cc_dilated & vent_dilated) & ~(cc_binary | vent_binary) - # Only fill small gaps between CC and ventricle in inferior-superior direction - if potential_fill.any(): - for z in range(potential_fill.shape[1]): - potential_fill_line = potential_fill[:, z] - labeled_gaps, num_gaps = ndimage.label(potential_fill_line) - cc_line = cc_binary[:, z] - vent_line = vent_binary[:, z] - - for gap_label in range(1, num_gaps + 1): - gap_mask = labeled_gaps == gap_label - - # check that CC and ventricle are connected to the gap_mask - dilated_gap_mask = ndimage.binary_dilation(gap_mask, iterations=1) - if not np.any(cc_line & dilated_gap_mask): - continue - if not np.any(vent_line & dilated_gap_mask): - continue - - vent_label_location = np.where(vent_line & dilated_gap_mask)[0] - vent_label = corrected_pred[x, vent_label_location, z] + # Only fill small gaps between CC and ventricle in inferior-superior direction + if potential_fill.any(): + for z in range(potential_fill.shape[1]): + potential_fill_line = potential_fill[:, z] + labeled_gaps, num_gaps = ndimage.label(potential_fill_line) + cc_line = cc_binary[:, z] + vent_line = vent_binary[:, z] + + for gap_label in range(1, num_gaps + 1): + gap_mask = labeled_gaps == gap_label + + # check that CC and ventricle are connected to the gap_mask + dilated_gap_mask = ndimage.binary_dilation(gap_mask, iterations=1) + if not np.any(cc_line & dilated_gap_mask): + continue + if not np.any(vent_line & dilated_gap_mask): + continue + + vent_label_location = np.where(vent_line & dilated_gap_mask)[0] + vent_label = corrected_pred[x, vent_label_location, z] + + if np.sum(gap_mask) > max_gap_vox: + continue + + corrected_pred[x, :, z][gap_mask & (corrected_pred[x, :, z] == 0)] = vent_label + + # Process gaps in z-direction (within each y-row) + for y in range(potential_fill.shape[0]): + potential_fill_line = potential_fill[y, :] + labeled_gaps, num_gaps = ndimage.label(potential_fill_line) + cc_line = cc_binary[y, :] + vent_line = vent_binary[y, :] + + for gap_label in range(1, num_gaps + 1): + gap_mask = labeled_gaps == gap_label + + # check that CC and ventricle are connected to the gap_mask + dilated_gap_mask = ndimage.binary_dilation(gap_mask, iterations=1) + if not np.any(cc_line & dilated_gap_mask): + continue + if not np.any(vent_line & dilated_gap_mask): + continue + + vent_label_location = np.where(vent_line & dilated_gap_mask)[0] + if len(vent_label_location) > 0: + vent_label = corrected_pred[x, y, vent_label_location[0]] # Take first match if np.sum(gap_mask) > max_gap_vox: continue - corrected_pred[x, :, z][gap_mask & (corrected_pred[x, :, z] == 0)] = vent_label - - # Process gaps in z-direction (within each y-row) - for y in range(potential_fill.shape[0]): - potential_fill_line = potential_fill[y, :] - labeled_gaps, num_gaps = ndimage.label(potential_fill_line) - cc_line = cc_binary[y, :] - vent_line = vent_binary[y, :] - - for gap_label in range(1, num_gaps + 1): - gap_mask = labeled_gaps == gap_label - - # check that CC and ventricle are connected to the gap_mask - dilated_gap_mask = ndimage.binary_dilation(gap_mask, iterations=1) - if not np.any(cc_line & dilated_gap_mask): - continue - if not np.any(vent_line & dilated_gap_mask): - continue - - vent_label_location = np.where(vent_line & dilated_gap_mask)[0] - if len(vent_label_location) > 0: - vent_label = corrected_pred[x, y, vent_label_location[0]] # Take first match - - if np.sum(gap_mask) > max_gap_vox: -
continue - - corrected_pred[x, y, :][gap_mask & (corrected_pred[x, y, :] == 0)] = vent_label + corrected_pred[x, y, :][gap_mask & (corrected_pred[x, y, :] == 0)] = vent_label return corrected_pred @@ -241,52 +266,67 @@ def correct_wm_ventricles( # Command Line options are error checking done here options = argument_parse() - - print(f"Reading inputs: {options.input_cc} {options.input_pred}...") - cc_seg_image = nib.load(options.input_cc) + cc_seg_image = cast(nib.analyze.SpatialImage, nib.load(options.input_cc)) cc_seg_data = np.asanyarray(cc_seg_image.dataobj) - aseg_image = nib.load(options.input_pred) + aseg_image = cast(nib.analyze.SpatialImage, nib.load(options.input_pred)) aseg_data = np.asanyarray(aseg_image.dataobj) cc_conformed = is_conform(cc_seg_image, vox_size=None, img_size=None, verbose=False) - pred_conformed = is_conform(aseg_image, vox_size=None, img_size=None, verbose=False) + pred_conformed = is_conform(aseg_image, vox_size=None, img_size=None, dtype=np.integer, verbose=False) if not cc_conformed: - print("Warning: CC input image is not conformed (LIA orientation, uint8 dtype). \ - Please conform the image using the conform.py script.") + sys.exit("Error: CC input image is not conformed (LIA orientation, uint8 dtype). \ + Please conform the image using the conform.py script.") if not pred_conformed: - print("Warning: Prediction input image is not conformed (LIA orientation, uint8 dtype). \ - Please conform the image using the conform.py script.") - - # Count initial labels - initial_cc = np.sum((aseg_data >= 251) & (aseg_data <= 255)) - initial_fornix = np.sum(aseg_data == 250) - initial_wm = np.sum((aseg_data == 2) | (aseg_data == 41)) - print(f"Initial segmentation: CC={initial_cc}, Fornix={initial_fornix}, WM={initial_wm}") + sys.exit("Error: Prediction input image is not conformed (LIA orientation, integer dtype). 
\ + Please conform the image using the conform.py script.") + if not np.allclose(cc_seg_image.affine, aseg_image.affine): + sys.exit("Error: The affine matrices of the aseg and the corpus callosum images are not the same.") # Paint CC into prediction pred_with_cc = paint_in_cc(aseg_data, cc_seg_data) - after_paint_cc = np.sum((pred_with_cc >= 251) & (pred_with_cc <= 255)) - print(f"After painting CC: {after_paint_cc} CC voxels added") # Apply WM and ventricle corrections print("Applying white matter and ventricle corrections...") - fornix_mask = cc_seg_data == 250 + fornix_mask = cc_seg_data == FORNIX_LABEL voxel_size = tuple(aseg_image.header.get_zooms()) pred_corrected = correct_wm_ventricles(aseg_data, fornix_mask, voxel_size) + print(f"Writing segmentation with corpus callosum to: {options.output}") + pred_with_cc_fin = nib.MGHImage(pred_corrected, aseg_image.affine, aseg_image.header) + io_fut = thread_executor().submit(pred_with_cc_fin.to_filename, options.output) + + if options.aseg is not None: + rta_fut = thread_executor().submit( + reduce_to_aseg_and_save, + pred_corrected, + aseg_image.affine, + aseg_image.header, + options.aseg, + ) + else: + rta_fut = None + + # Count initial labels + initial_cc = np.sum(mask_in_array(aseg_data, SUBSEGMENT_LABELS)) + initial_fornix = np.sum(aseg_data == FORNIX_LABEL) + initial_wm = np.sum((aseg_data == 2) | (aseg_data == 41)) + print(f"Initial segmentation: CC={initial_cc}, Fornix={initial_fornix}, WM={initial_wm}") + + after_paint_cc = np.sum(mask_in_array(pred_with_cc, SUBSEGMENT_LABELS)) + print(f"After painting CC: {after_paint_cc} CC voxels added") + # Count final labels - final_cc = np.sum((pred_corrected >= 251) & (pred_corrected <= 255)) - final_fornix = np.sum(pred_corrected == 250) + final_cc = np.sum(mask_in_array(pred_corrected, SUBSEGMENT_LABELS)) + final_fornix = np.sum(pred_corrected == FORNIX_LABEL) final_wm = np.sum((pred_corrected == 2) | (pred_corrected == 41)) final_ventricles = np.sum((pred_corrected == 4) | (pred_corrected == 43)) print(f"Final segmentation: CC={final_cc}, Fornix={final_fornix}, WM={final_wm}, Ventricles={final_ventricles}") print(f"Changes: CC +{final_cc-initial_cc}, Fornix {final_fornix-initial_fornix}, WM {final_wm-initial_wm}") + _ = io_fut.result() # make sure the asynchronous write has finished before exiting + if rta_fut is not None: + _ = rta_fut.result() sys.exit(0)
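The `map_softlabels_to_orig` rewrite in the next file replaces explicit padding with a composed voxel-to-voxel matrix handed to `scipy.ndimage.affine_transform`. The convention to keep in mind is that the matrix maps output voxel coordinates to input coordinates, which is why the slab-to-fsaverage translation gets folded in. A toy example of that convention:

```python
# scipy.ndimage.affine_transform maps OUTPUT voxel coordinates to INPUT
# coordinates, so a slab stored at index 0 that conceptually lives at
# index `middle` of a larger volume needs the translation composed in.
import numpy as np
from scipy.ndimage import affine_transform

middle = 4
slab = np.zeros((1, 3, 3))
slab[0, 1, 1] = 1.0  # single labeled voxel in a 1-slice slab

vol2slab = np.eye(4)  # 4x4 homogeneous matrix for a 3D transform
vol2slab[0, 3] = -middle  # output index x=middle looks up slab index 0

out = affine_transform(slab, vol2slab, output_shape=(9, 3, 3), order=0)
print(np.argwhere(out))  # [[4 1 1]] -> the voxel lands on the middle slice
```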
diff --git a/CorpusCallosum/registration/mapping_helpers.py b/CorpusCallosum/registration/mapping_helpers.py index a9a1a7cc..dcecd9fd 100644 --- a/CorpusCallosum/registration/mapping_helpers.py +++ b/CorpusCallosum/registration/mapping_helpers.py @@ -6,7 +6,9 @@ from numpy import typing as npt from scipy.ndimage import affine_transform +from CorpusCallosum.data.constants import CC_LABEL, FORNIX_LABEL from FastSurferCNN.utils import logging +from FastSurferCNN.utils.parallel import thread_executor logger = logging.get_logger(__name__) @@ -292,31 +294,28 @@ def make_affine(simpleITKImage: 'sitk.Image') -> npt.NDArray[float]: def map_softlabels_to_orig( - outputs_soft: npt.NDArray[float], + cc_fn_softlabels: npt.NDArray[float], orig_fsaverage_vox2vox: npt.NDArray[float], orig: nib.analyze.SpatialImage, - slices_to_analyze: int, orig_space_segmentation_path: str | Path | None = None, fsaverage_middle: int = 128, - subdivision_mask: npt.NDArray[int] | None = None + cc_subseg_midslice: npt.NDArray[int] | None = None ) -> npt.NDArray[int]: """Map soft labels back to original image space and apply post-processing. Parameters ---------- - outputs_soft : np.ndarray + cc_fn_softlabels : np.ndarray Soft label predictions. orig_fsaverage_vox2vox : np.ndarray Original to fsaverage space transformation. orig : nibabel.analyze.SpatialImage Original image. - slices_to_analyze : int - Number of slices to analyze. orig_space_segmentation_path : str or Path, optional Path to save segmentation in original space. fsaverage_middle : int, default=128 Middle slice index in fsaverage space. - subdivision_mask : npt.NDArray[int], optional + cc_subseg_midslice : npt.NDArray[int], optional Mask for subdividing regions. Returns @@ -327,64 +326,50 @@ def map_softlabels_to_orig( Notes ----- The function: - 1. Pads soft labels to original image size - 2. Transforms each label channel separately - 3. Applies post-processing if needed - 4. Optionally saves result to file - - TODO: This could be optimized by padding after the transform + 1. Transforms background, cc, and fornix label channels separately. + 2. Transforms CC subsegmentation from midslice to orig and paints it into the segmentation if `cc_subseg_midslice` is passed. + 3. Saves result to `orig_space_segmentation_path` if passed. """ - + slices_to_analyze = cc_fn_softlabels.shape[0] # map softlabels to original image - pad_lr = (fsaverage_middle - slices_to_analyze // 2, fsaverage_middle + slices_to_analyze // 2 + 1) - pad_tuples = (pad_lr,) + ((0, 0),) * (orig.ndim - 1) - softlabels_transformed = [] - for i in range(outputs_soft.shape[-1]): - # pad to original image size - outputs_soft_padded = np.pad(outputs_soft[..., i], pad_tuples) - - s = affine_transform( - outputs_soft_padded, - orig_fsaverage_vox2vox, - output_shape=orig.shape, - order=1, - cval=1.0 if i == 0 else 0.0, - ) - softlabels_transformed.append(s) + slab2fsaverage_vox2vox = np.eye(4) + slab2fsaverage_vox2vox[0, 3] = -(fsaverage_middle - slices_to_analyze // 2) + slab2orig_vox2vox = orig_fsaverage_vox2vox @ slab2fsaverage_vox2vox - softlabels_orig_space = np.stack(softlabels_transformed, axis=-1) + def _map_softlabel_to_orig(i: int, data: np.ndarray) -> np.ndarray: + return affine_transform(data, slab2orig_vox2vox, output_shape=orig.shape, order=1, cval=float(i == 0)) - # apply softmax to softlabels_orig_space - exp_orig_space = np.exp(softlabels_orig_space) - softlabels_orig_space = exp_orig_space / np.sum(exp_orig_space, axis=-1, keepdims=True) + _softlabels = np.moveaxis(cc_fn_softlabels, -1, 0) + softlabels_transformed = thread_executor().map(_map_softlabel_to_orig, *zip(*enumerate(_softlabels), strict=True)) - segmentation_orig_space = np.argmax(softlabels_orig_space, axis=-1) + softlabels_orig_space = np.stack(list(softlabels_transformed), axis=-1) + seg_orig_space = np.argmax(softlabels_orig_space, axis=-1) + # map to freesurfer labels + seg_lut = np.asarray([0, CC_LABEL, FORNIX_LABEL]) + seg_orig_space = seg_lut[seg_orig_space] - if subdivision_mask is not None: - # repeat subdivision mask for shape 0 of orig - subdivision_mask = np.repeat(subdivision_mask[np.newaxis], orig.shape[0], axis=0) + if cc_subseg_midslice is not None: # map subdivision mask to orig space - subdivision_mask_orig_space = affine_transform( - subdivision_mask, - orig_fsaverage_vox2vox, + midslice2fsaverage_vox2vox = np.eye(4) + midslice2fsaverage_vox2vox[0, 3] = -fsaverage_middle + cc_subseg_orig_space = affine_transform( + cc_subseg_midslice[None], + orig_fsaverage_vox2vox @ midslice2fsaverage_vox2vox, output_shape=orig.shape, order=0, + mode="nearest", ) - mask = segmentation_orig_space == 
1 - segmentation_orig_space[mask] *= subdivision_mask_orig_space[mask] - - seg_lut = np.asarray([0, 192, 250]) - segmentation_orig_space = seg_lut[segmentation_orig_space] + seg_orig_space = np.where(seg_orig_space == CC_LABEL, cc_subseg_orig_space, seg_orig_space) if orig_space_segmentation_path is not None: logger.info(f"Saving segmentation in original space to {orig_space_segmentation_path}") nib.save( - nib.MGHImage(segmentation_orig_space, orig.affine, orig.header), + nib.MGHImage(seg_orig_space, orig.affine, orig.header), orig_space_segmentation_path, ) - return segmentation_orig_space + return seg_orig_space def interpolate_midplane( diff --git a/CorpusCallosum/segmentation/segmentation_inference.py b/CorpusCallosum/segmentation/segmentation_inference.py index 6aed9ce4..cb9a5a85 100644 --- a/CorpusCallosum/segmentation/segmentation_inference.py +++ b/CorpusCallosum/segmentation/segmentation_inference.py @@ -25,7 +25,7 @@ from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults from FastSurferCNN.download_checkpoints import main as download_checkpoints from FastSurferCNN.models.networks import FastSurferVINN -from FastSurferCNN.utils.common import thread_executor as executor +from FastSurferCNN.utils.parallel import thread_executor def load_model(device: torch.device | None = None) -> FastSurferVINN: @@ -179,7 +179,7 @@ def _load(label_path: str | Path) -> int: return last_nonzero - first_nonzero else: return label_img.shape[0] - label_widths = executor().map(_load, data["label"]) + label_widths = thread_executor().map(_load, data["label"]) return images, ac_centers, pc_centers, label_widths, labels, subj_ids diff --git a/CorpusCallosum/shape/cc_mesh.py b/CorpusCallosum/shape/cc_mesh.py index ed41150e..ef372073 100644 --- a/CorpusCallosum/shape/cc_mesh.py +++ b/CorpusCallosum/shape/cc_mesh.py @@ -22,11 +22,13 @@ import numpy as np import plotly.graph_objects as go import scipy.interpolate +from plotly.io import write_html as plotly_write_html from scipy.ndimage import gaussian_filter1d import FastSurferCNN.utils.logging as logging from CorpusCallosum.shape.cc_endpoint_heuristic import smooth_contour -from CorpusCallosum.shape.cc_thickness import HiddenPrints, make_mesh_from_contour +from CorpusCallosum.shape.cc_thickness import make_mesh_from_contour +from FastSurferCNN.utils.common import suppress_stdout try: from pyrr import Matrix44 @@ -444,14 +446,14 @@ def plot_mesh( if output_path is not None: self.__make_parent_folder(output_path) - fig.write_html(output_path) # Save as interactive HTML + plotly_write_html(fig, output_path, include_plotlyjs="cdn") # Save as interactive HTML else: # For non-interactive display, save to a temporary HTML and open in browser import tempfile import webbrowser temp_path = Path(tempfile.gettempdir()) / "cc_mesh_plot.html" - fig.write_html(temp_path) + plotly_write_html(fig, temp_path, include_plotlyjs="cdn") webbrowser.open(f"file://{temp_path}") def get_contour_edge_lengths(self, contour_idx: int) -> np.ndarray: @@ -556,7 +558,8 @@ def _create_levelpaths( 3. Solves Poisson equation for level sets 4. 
Extracts level paths and interpolates thickness values """ - with HiddenPrints(): + + with suppress_stdout(): cc_tria = lapy.TriaMesh(points, trias) # extract boundary curve bdr = np.array(cc_tria.boundary_loops()[0]) @@ -576,7 +579,7 @@ def _create_levelpaths( dcond[iidx1 + 1 : iidx2] = -1 # Extract path - with HiddenPrints(): + with suppress_stdout(): fem = lapy.Solver(cc_tria) vfunc = fem.poisson(0, (bdr, dcond)) if num_points is not None: diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/cc_postprocessing.py index cdda1b96..a4ed63cf 100644 --- a/CorpusCallosum/shape/cc_postprocessing.py +++ b/CorpusCallosum/shape/cc_postprocessing.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import concurrent.futures -from pathlib import Path +from functools import partial from typing import Literal, get_args import numpy as np @@ -31,12 +31,12 @@ transform_to_acpc_standard, ) from CorpusCallosum.shape.cc_thickness import cc_thickness, convert_to_ras -from CorpusCallosum.utils.utils import HiddenPrints from CorpusCallosum.visualization.visualization import plot_contours -from FastSurferCNN.utils.common import SubjectDirectory, update_docstring -from FastSurferCNN.utils.common import thread_executor as executor +from FastSurferCNN.utils.common import SubjectDirectory, suppress_stdout, update_docstring +from FastSurferCNN.utils.parallel import process_executor, thread_executor SubdivisionMethod = Literal["shape", "vertical", "angular", "eigenvector"] +SliceSelection = Literal["middle", "all"] | int logger = logging.get_logger(__name__) @@ -47,61 +47,6 @@ LIA_ORIENTATION[2,1] = -1 -@update_docstring(SubdivisionMethod=str(get_args(SubdivisionMethod))[1:-1]) -def async_create_visualization( - subdivision_method: SubdivisionMethod, - result: dict, - midslices_data: np.ndarray, - output_image_path: str | Path, - ac_coords: np.ndarray, - pc_coords: np.ndarray, - vox_size: float, - title_suffix: str = "", -) -> concurrent.futures.Future: - """Create visualization plots based on subdivision method. - - Parameters - ---------- - subdivision_method : {SubdivisionMethod} - The subdivision method being used. - result : dict - Dictionary containing processing results with split_contours. - midslices_data : np.ndarray - Slice data for visualization. - output_image_path : Path, str - Path to save visualization. - ac_coords : np.ndarray - AC coordinates. - pc_coords : np.ndarray - PC coordinates. - vox_size : float - Voxel size in mm. - title_suffix : str, optional - Additional text to append to the title, by default "". - - Returns - ------- - multiprocessing.Process - Process object for background execution. - """ - title = f"CC Subsegmentation by {subdivision_method} {title_suffix}" - - args_dict = { - "debug": True, - "transformed": midslices_data, - "split_contours": result["split_contours"], - "midline_equidistant": result["midline_equidistant"], - "levelpaths": result["levelpaths"], - "output_path": output_image_path, - "ac_coords": ac_coords, - "pc_coords": pc_coords, - "vox_size": vox_size, - "title": title, - } - - return executor().submit(plot_contours, **args_dict) - - def create_slice_affine(temp_seg_affine: np.ndarray, slice_idx: int, fsaverage_middle: int) -> np.ndarray: """Create slice-specific affine transformation matrix. 
@@ -127,7 +72,7 @@ def create_slice_affine(temp_seg_affine: np.ndarray, slice_idx: int, fsaverage_m @update_docstring(SubdivisionMethod=str(get_args(SubdivisionMethod))[1:-1]) def recon_cc_surf_measures_multi( segmentation: np.ndarray, - slice_selection: str, + slice_selection: SliceSelection, temp_seg_affine: np.ndarray, midslices: np.ndarray, ac_coords: np.ndarray, @@ -137,7 +82,6 @@ def recon_cc_surf_measures_multi( subdivision_method: SubdivisionMethod, contour_smoothing: float, subject_dir: SubjectDirectory, - qc_image_path: str | None = None, vox_size: tuple[float, float, float] | None = None, vox2ras_tkr: np.ndarray | None = None, ) -> tuple[list, list[concurrent.futures.Future]]: @@ -167,8 +111,6 @@ def recon_cc_surf_measures_multi( Gaussian sigma for contour smoothing. subject_dir : SubjectDirectory The SubjectDirectory object managing file names in the subject directory. - qc_image_path : Path, str, optional - Path for QC visualization image. vox_size : 3-tuple of floats, optional Voxel size in millimeters (x, y, z). vox2ras_tkr : np.ndarray, optional @@ -184,104 +126,90 @@ def recon_cc_surf_measures_multi( slice_results = [] io_futures = [] - if slice_selection == "middle": - cc_mesh = CCMesh(num_slices=1) - cc_mesh.set_acpc_coords(ac_coords, pc_coords) - cc_mesh.set_resolution(vox_size[0]) - - # Process only the middle slice - slice_idx = segmentation.shape[0] // 2 - slice_affine = create_slice_affine(temp_seg_affine, slice_idx, FSAVERAGE_MIDDLE) - - result, contour_with_thickness, *endpoint_idxs = recon_cc_surf_measure( - segmentation, - slice_idx, - ac_coords, - pc_coords, - slice_affine, - num_thickness_points, - subdivisions, - subdivision_method, - contour_smoothing, - vox_size[0], + if subdivision_method == "angular" and not np.allclose(np.diff(subdivisions), np.diff(subdivisions)[0]): + raise ValueError( + f"Angular subdivision method (Hampel) only supports equidistant subdivision, " + f"but got: {subdivisions}. 
No measures are computed.", ) - cc_mesh.add_contour(0, *contour_with_thickness, start_end_idx=endpoint_idxs) - - if result is not None and qc_image_path is not None: - slice_results.append(result) - # Create visualization - logger.info(f"Saving segmentation qc image to {qc_image_path}") - io_futures.append(async_create_visualization( - subdivision_method, - result, - midslices, - qc_image_path, - ac_coords, - pc_coords, - vox_size[0], - )) - else: + _each_slice = partial(recon_cc_surf_measure, + segmentation, + ac_coords=ac_coords, + pc_coords=pc_coords, + num_thickness_points=num_thickness_points, + subdivisions=subdivisions, + subdivision_method=subdivision_method, + contour_smoothing=contour_smoothing, + vox_size=vox_size[0], + ) + + # Process multiple slices or specific slice + if slice_selection == "middle": + num_slices = 1 + # Process only the middle slice + slice_iterator = [segmentation.shape[0] // 2] + elif slice_selection == "all": num_slices = segmentation.shape[0] - cc_mesh = CCMesh(num_slices=num_slices) - cc_mesh.set_acpc_coords(ac_coords, pc_coords) - cc_mesh.set_resolution(vox_size[0]) - - # Process multiple slices or specific slice - if slice_selection == "all": - start_slice = 0 - end_slice = segmentation.shape[0] - else: # specific slice number - slice_idx = int(slice_selection) - start_slice = slice_idx - end_slice = slice_idx + 1 - - for slice_idx in range(start_slice, end_slice): - logger.info(f"Calculating CC measurements for slice {slice_idx+1} of {end_slice-start_slice}") - - # Update affine for this slice - slice_affine = create_slice_affine(temp_seg_affine, slice_idx, FSAVERAGE_MIDDLE) - - # Process this slice - result, contour_with_thickness, *endpoint_idxs = recon_cc_surf_measure( - segmentation, - slice_idx, - ac_coords, - pc_coords, - slice_affine, - num_thickness_points, - subdivisions, - subdivision_method, - contour_smoothing, - vox_size[0], + start_slice = 0 + end_slice = segmentation.shape[0] + slice_iterator = range(start_slice, end_slice) + else: # specific slice number + num_slices = 1 + slice_iterator = [int(slice_selection)] + + it_affine = map(partial(create_slice_affine, temp_seg_affine, fsaverage_middle=FSAVERAGE_MIDDLE), slice_iterator) + + iterator = process_executor().map(_each_slice, iter(slice_iterator), it_affine, chunksize=1) + cc_mesh = CCMesh(num_slices=num_slices) + cc_mesh.set_acpc_coords(ac_coords, pc_coords) + cc_mesh.set_resolution(vox_size[0]) + + def _yield_iterator(): + for _slice_idx in slice_iterator: + try: + yield _slice_idx, *next(iterator) + except ValueError as e: + logger.error(f"Slice {_slice_idx} failed with error: {e}") + logger.exception(e) + except StopIteration: + logger.error(f"Unexpectedly skipping slice {_slice_idx} in CC surfaces.") + return + + for i, (slice_idx, result, contour_with_thickness, *endpoint_idxs) in enumerate(_yield_iterator()): + # insert + progress = f" ({i+1} of {num_slices})" if num_slices > 1 else "" + logger.info(f"Calculating CC measurements for slice {slice_idx+1}{progress}") + cc_mesh.add_contour(slice_idx, *contour_with_thickness, start_end_idx=endpoint_idxs) + if result is None: + continue + + slice_results.append(result) + + if logger.getEffectiveLevel() <= logging.INFO and subject_dir.has_attribute("cc_qc_image"): + qc_img = subject_dir.filename_by_attribute("cc_qc_image") + if logger.getEffectiveLevel() <= logging.DEBUG: + qc_img = (qc_img.parent / f"{qc_img.stem}_slice_{slice_idx}{qc_img.suffix}").with_suffix(".png") + + if logger.getEffectiveLevel() <= logging.DEBUG or slice_idx 
+ + if logger.getEffectiveLevel() <= logging.DEBUG or i == num_slices // 2: + logger.info(f"Saving segmentation qc image to {qc_img}") + + current_slice_in_volume = midslices.shape[0] // 2 - segmentation.shape[0] // 2 + slice_idx + # Create visualization for this slice + io_futures.append( + thread_executor().submit( + plot_contours, + transformed=midslices[current_slice_in_volume:current_slice_in_volume+1], + split_contours=result["split_contours"], + midline_equidistant=result["midline_equidistant"], + levelpaths=result["levelpaths"], + output_path=qc_img, + ac_coords=ac_coords, + pc_coords=pc_coords, + vox_size=vox_size[0], + title=f"CC Subsegmentation by {subdivision_method} (Slice {slice_idx})", + ) ) - # insert - cc_mesh.add_contour(slice_idx, *contour_with_thickness, start_end_idx=endpoint_idxs) - - if result is not None: - slice_results.append(result) - - if logger.getEffectiveLevel() <= logging.INFO and subject_dir.has_attribute("cc_qc_image"): - qc_img = subject_dir.filename_by_attribute("cc_qc_image") - if logger.getEffectiveLevel() <= logging.DEBUG: - qc_img = (qc_img.parent / f"{qc_img.stem}_slice_{slice_idx}{qc_img.suffix}").with_suffix(".png") - - if logger.getEffectiveLevel() <= logging.DEBUG or slice_idx == num_slices // 2: - logger.info(f"Saving segmentation qc image to {qc_img}") - - current_slice_in_volume = midslices.shape[0] // 2 - num_slices // 2 + slice_idx - # Create visualization for this slice - io_futures.append(async_create_visualization( - subdivision_method, - result, - midslices[current_slice_in_volume:current_slice_in_volume+1], - qc_img, - ac_coords, - pc_coords, - vox_size[0], - f" (Slice {slice_idx})", - )) if subject_dir.has_attribute("save_template_dir"): template_dir = subject_dir.filename_by_attribute("save_template_dir") @@ -289,38 +217,49 @@ def recon_cc_surf_measures_multi( template_dir.mkdir(parents=True, exist_ok=True) logger.info("Saving template files (contours.txt, thickness_values.txt, " f"thickness_measurement_points.txt) to {template_dir}") - cc_mesh.save_contours(template_dir / "contours.txt") - cc_mesh.save_thickness_values(template_dir / "thickness_values.txt") - cc_mesh.save_thickness_measurement_points(template_dir / "thickness_measurement_points.txt") - - - if len(cc_mesh.contours) > 1 and subject_dir.has_attribute("cc_html"): + for fut in [ + thread_executor().submit(cc_mesh.save_contours, template_dir / "contours.txt"), + thread_executor().submit(cc_mesh.save_thickness_values, template_dir / "thickness_values.txt"), + thread_executor().submit(cc_mesh.save_thickness_measurement_points, + template_dir / "thickness_measurement_points.txt"), + ]: + if fut.exception(): + logger.exception(fut.exception()) + + mesh_outputs = ("html", "mesh", "thickness_overlay", "surf", "thickness_image") + if len(cc_mesh.contours) > 1 and any(subject_dir.has_attribute(f"cc_{n}") for n in mesh_outputs): cc_mesh.fill_thickness_values() cc_mesh.create_mesh() cc_mesh.smooth_(1) - logger.info(f"Saving CC 3D visualization to {subject_dir.filename_by_attribute('cc_html')}") - cc_mesh.plot_mesh(output_path=subject_dir.filename_by_attribute("cc_html"), show_mesh_edges=True) + if subject_dir.has_attribute("cc_html"): + logger.info(f"Saving CC 3D visualization to {subject_dir.filename_by_attribute('cc_html')}") + io_futures.append(thread_executor().submit( + cc_mesh.plot_mesh, + output_path=subject_dir.filename_by_attribute("cc_html"), + show_mesh_edges=True, + )) if subject_dir.has_attribute("cc_mesh"): vtk_file_path = subject_dir.filename_by_attribute("cc_mesh") logger.info(f"Saving vtk file to {vtk_file_path}") -
cc_mesh.write_vtk(vtk_file_path) + io_futures.append(thread_executor().submit(cc_mesh.write_vtk, vtk_file_path)) cc_mesh.to_fs_coordinates(vox2ras_tkr=vox2ras_tkr, vox_size=vox_size) - if subject_dir.has_attribute("overlay_file"): - overlay_file_path = subject_dir.filename_by_attribute("overlay_file") + if subject_dir.has_attribute("cc_thickness_overlay"): + overlay_file_path = subject_dir.filename_by_attribute("cc_thickness_overlay") logger.info(f"Saving overlay file to {overlay_file_path}") - cc_mesh.write_overlay(overlay_file_path) + io_futures.append(thread_executor().submit(cc_mesh.write_overlay, overlay_file_path)) - if subject_dir.has_attribute("cc_surf_file"): - surf_file_path = subject_dir.filename_by_attribute("cc_surf_file") + if subject_dir.has_attribute("cc_surf"): + surf_file_path = subject_dir.filename_by_attribute("cc_surf") logger.info(f"Saving surf file to {surf_file_path}") - cc_mesh.write_fssurf(surf_file_path) + io_futures.append(thread_executor().submit(cc_mesh.write_fssurf, surf_file_path)) - if subject_dir.has_attribute("thickness_image"): - thickness_image_path = subject_dir.filename_by_attribute("thickness_image") + if subject_dir.has_attribute("cc_thickness_image"): + thickness_image_path = subject_dir.filename_by_attribute("cc_thickness_image") logger.info(f"Saving thickness image to {thickness_image_path}") - with HiddenPrints(): + # note: suppress_stdout is not thread-safe! But it works fine, if only one thread uses it... + with suppress_stdout(): cc_mesh.snap_cc_picture(thickness_image_path) @@ -334,14 +273,14 @@ def recon_cc_surf_measures_multi( def recon_cc_surf_measure( segmentation: np.ndarray, slice_idx: int, + affine: np.ndarray, ac_coords: np.ndarray, pc_coords: np.ndarray, - affine: np.ndarray, num_thickness_points: int, subdivisions: list[float], subdivision_method: SubdivisionMethod, contour_smoothing: float, - vox_size: float + vox_size: float, ) -> tuple[dict[str, float | int | np.ndarray | list[float]], np.ndarray, int, int]: """Reconstruct surfaces and compute measures for a single slice for the corpus callosum. @@ -351,12 +290,12 @@ def recon_cc_surf_measure( 3D segmentation array. slice_idx : int Index of the slice to process. + affine : np.ndarray + 4x4 affine transformation matrix. ac_coords : np.ndarray Anterior commissure coordinates. pc_coords : np.ndarray Posterior commissure coordinates. - affine : np.ndarray - 4x4 affine transformation matrix. num_thickness_points : int Number of points for thickness estimation. subdivisions : list[float] @@ -447,9 +386,10 @@ def recon_cc_surf_measure( areas, split_contours = subdivide_contour(contour_acpc, subdivisions, plot=False) elif subdivision_method == "angular": if not np.allclose(np.diff(subdivisions), np.diff(subdivisions)[0]): - logger.error("Error: Angular subdivision method (Hampel) only supports equidistant subdivision, " - f"but got: {subdivisions}. No measures are computed.") - return {}, contour_with_thickness, *endpoint_idxs + raise ValueError( + f"Angular subdivision method (Hampel) only supports equidistant subdivision, " + f"but got: {subdivisions}. 
No measures are computed.", + ) areas, split_contours = hampel_subdivide_contour(contour_acpc, num_rays=len(subdivisions), plot=False) elif subdivision_method == "eigenvector": pt0, pt1 = get_primary_eigenvector(contour_acpc) diff --git a/CorpusCallosum/shape/cc_thickness.py b/CorpusCallosum/shape/cc_thickness.py index a8b0d6e7..69bfff04 100644 --- a/CorpusCallosum/shape/cc_thickness.py +++ b/CorpusCallosum/shape/cc_thickness.py @@ -19,7 +19,7 @@ from lapy.diffgeo import compute_rotated_f from meshpy import triangle -from CorpusCallosum.utils.utils import HiddenPrints +from FastSurferCNN.utils.common import suppress_stdout def compute_curvature(path: np.ndarray) -> np.ndarray: @@ -275,20 +275,7 @@ def make_mesh_from_contour( of the contour. The contour must not have duplicate points. """ - facets = np.vstack( - ( - np.arange(len(contour_2d)), - ((np.arange(len(contour_2d)) + 1) % len(contour_2d)), - ) - ).T - - # plot vertices and facets - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots(figsize=(10, 8)) - # ax.scatter(contour_2d[:,0], contour_2d[:,1], label='Contour') - # ax.plot(contour_2d[:,0], contour_2d[:,1], 'k-', label='Contour') - # ax.plot(contour_2d[facets[:,0],0], contour_2d[facets[:,0],1], 'r-', label='Facets') - # plt.show() + facets = np.vstack((np.arange(len(contour_2d)), ((np.arange(len(contour_2d)) + 1) % len(contour_2d)))).T # use meshpy to create mesh info = triangle.MeshInfo() @@ -296,7 +283,7 @@ def make_mesh_from_contour( info.set_facets(facets) # NOTE: crashes if contour has duplicate points !! mesh = triangle.build( - info, max_volume=max_volume, min_angle=min_angle, verbose=verbose + info, max_volume=max_volume, min_angle=min_angle, verbose=verbose, ) mesh_points = np.array(mesh.points) @@ -342,93 +329,56 @@ def cc_thickness( # standardize contour indices, to get consistent levelpath directions contour_2d, anterior_endpoint_idx, posterior_endpoint_idx = set_contour_zero_idx( - contour_2d, anterior_endpoint_idx, anterior_endpoint_idx, posterior_endpoint_idx + contour_2d, anterior_endpoint_idx, anterior_endpoint_idx, posterior_endpoint_idx, ) mesh_points, mesh_trias = make_mesh_from_contour(contour_2d) - # plot mesh points with index next to point - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots(figsize=(10, 8)) - # ax.plot(mesh_points[:,0], mesh_points[:,1], label='Mesh Points') - # for i in range(len(mesh_points)): - # ax.text(mesh_points[i,0], mesh_points[i,1], str(i), fontsize=7) - # plt.show() - # make points 3D by appending z=0 mesh_points3d = np.append(mesh_points, np.zeros((mesh_points.shape[0], 1)), axis=1) # compute poisson - with HiddenPrints(): + with suppress_stdout(): tria = TriaMesh(mesh_points3d, mesh_trias) - # extract boundary curve - bdr = np.array(tria.boundary_loops()[0]) - - # find index of endpoints in bdr list - iidx1 = np.where(bdr == anterior_endpoint_idx)[0][0] - iidx2 = np.where(bdr == posterior_endpoint_idx)[0][0] - - # create boundary condition (0 at endpoints, -1 on one side, 1 on the other): - if iidx1 > iidx2: - tmp = iidx2 - iidx2 = iidx1 - iidx1 = tmp - dcond = np.ones(bdr.shape) - dcond[iidx1] = 0 - dcond[iidx2] = 0 - dcond[iidx1 + 1 : iidx2] = -1 - - # Extract path - with HiddenPrints(): + # extract boundary curve + bdr = np.array(tria.boundary_loops()[0]) + + # find index of endpoints in bdr list + iidx1 = np.where(bdr == anterior_endpoint_idx)[0][0] + iidx2 = np.where(bdr == posterior_endpoint_idx)[0][0] + + # create boundary condition (0 at endpoints, -1 on one side, 1 on the other): + if iidx1 > 
iidx2: + iidx1, iidx2 = iidx2, iidx1 + dcond = np.ones(bdr.shape) + dcond[iidx1] = 0 + dcond[iidx2] = 0 + dcond[iidx1 + 1 : iidx2] = -1 + + # Extract path fem = Solver(tria) vfunc = fem.poisson(0, (bdr, dcond)) - level = 0 - midline_equidistant, midline_length = tria.level_path( - vfunc, level, n_points=n_points + 2 - ) - midline_equidistant = midline_equidistant[:, :2] + midline_equidistant, midline_length = tria.level_path(vfunc, level=0., n_points=n_points + 2) + midline_equidistant = midline_equidistant[:, :2] - # try: - with HiddenPrints(): gf = compute_rotated_f(tria, vfunc) - # except Exception as e: - # Lot contour and path - # import matplotlib.pyplot as plt - # import matplotlib.tri as tri - # fig, ax = plt.subplots(figsize=(10, 8)) - # # Plot contours - # ax.plot(contour_2d[:,0], contour_2d[:,1], 'k-', label='Contour', marker='o', markersize=3) - # ax.plot(midline_equidistant[:,0], midline_equidistant[:,1], 'g-', label='Level0', marker='o', markersize=2) - # # plot mesh - # mtpltlb_tria = tri.Triangulation(tria.v[:,0], tria.v[:,1], triangles=tria.t) - # ax.triplot(mtpltlb_tria, 'k-', alpha=0.2, linewidth=0.5) - # # Plot final endpoint estimates - # ax.plot(contour_2d[:,0][anterior_endpoint_idx], contour_2d[:,1][anterior_endpoint_idx], 'r*', - # markersize=15, label='Final estimate') - # ax.plot(contour_2d[:,0][posterior_endpoint_idx], contour_2d[:,1][posterior_endpoint_idx], 'r*', - # markersize=15, label='Final estimate') - # ax.legend() - # #ax.set_title(f'Subject: {subj_id}') - # plt.show() - - # interpolate midline to get levels to evaluate - gf_interp = scipy.interpolate.griddata( - tria.v[:, 0:2], gf, midline_equidistant[:, 0:2], method="cubic" - ) - # get levels to evaluate - # level_length = tria.level_length(gf, gf_interp) + # interpolate midline to get levels to evaluate + level_of_rotated_laplace = scipy.interpolate.griddata( + tria.v[:, 0:2], gf, midline_equidistant[:, 0:2], method="cubic", + ) + # get levels to evaluate levelpaths = [] levelpath_lengths = [] levelpath_tria_idx = [] + # now, on the rotated laplace function, sample equally spaced (on midline: level_of_rotated_laplace) levelpaths contour_with_thickness = [contour_2d.copy(), np.full(contour_2d.shape[0], np.nan)] - for i in range(1, n_points + 1): - level = gf_interp[i] + for current_level in level_of_rotated_laplace[1:-1]: # levelpath starts at index zero lvlpath, lvlpath_length, tria_idx = tria.level_path( - gf, level, get_tria_idx=True + gf, current_level, get_tria_idx=True, ) levelpaths.append(lvlpath) @@ -439,10 +389,10 @@ def cc_thickness( levelpath_end = lvlpath[-1, :2] contour_with_thickness, inserted_idx_start = insert_point_with_thickness( - contour_with_thickness, levelpath_start, lvlpath_length, get_index=True + contour_with_thickness, levelpath_start, lvlpath_length, get_index=True, ) contour_with_thickness, inserted_idx_end = insert_point_with_thickness( - contour_with_thickness, levelpath_end, lvlpath_length, get_index=True + contour_with_thickness, levelpath_end, lvlpath_length, get_index=True, ) # keep track of start and end indices @@ -456,91 +406,9 @@ def cc_thickness( if inserted_idx_end >= posterior_endpoint_idx: posterior_endpoint_idx += 1 - # import matplotlib.pyplot as plt - - # fig, ax = plt.subplots(figsize=(10, 8)) - # cont = contour_with_thickness[0] - # ax.plot(cont[:,0], cont[:,1], 'k-', label='Contour', marker='o', markersize=3) - # ax.scatter(cont[:,0][anterior_endpoint_idx], cont[:,1][anterior_endpoint_idx], c='r', - # label='Anterior Endpoint', marker='o') - # 
ax.scatter(cont[:,0][posterior_endpoint_idx], cont[:,1][posterior_endpoint_idx], c='b', - # label='Posterior Endpoint', marker='o') - # ax.legend() - # plt.show() - - # thickness_measurement_points_top = [] - # thickness_measurement_points_bottom = [] - # for i in range(len(levelpaths)): - # thickness_measurement_points_top.append(levelpaths[i][0,:2]) - # thickness_measurement_points_bottom.append(levelpaths[i][-1,:2]) - - # thickness_measurement_points_top = np.array(thickness_measurement_points_top) - # thickness_measurement_points_bottom = np.array(thickness_measurement_points_bottom) - # thickness_measurement_points = np.concatenate([thickness_measurement_points_top, - # thickness_measurement_points_bottom], axis=0).T - - # # Create a figure with subplots - # fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5)) - - # # Plot 1: Contour - # ax1.plot(contour_2d[:,0], -contour_2d[:,1], 'b-', linewidth=2, label='Contour') - # ax1.set_title('Corpus Callosum Contour') - # ax1.set_xlabel('X') - # ax1.set_ylabel('Y') - # ax1.axis('equal') - # ax1.invert_yaxis() - # ax1.legend() - - # # Plot 2: Thickness measurement points - # print(thickness_measurement_points.shape) - # ax2.plot(thickness_measurement_points[0, :100], -thickness_measurement_points[1, :100], 'ro', - # markersize=3, label='Thickness Points (start)') - # ax2.plot(thickness_measurement_points[0, 100:], -thickness_measurement_points[1, 100:], 'go', - # markersize=3, label='Thickness Points (end)') - # ax2.set_title('Thickness Measurement Points') - # ax2.set_xlabel('X') - # ax2.set_ylabel('Y') - # ax2.axis('equal') - # ax2.invert_yaxis() - # ax2.legend() - # plt.show() - # get curvature of path3d_resampled curvature = compute_curvature(midline_equidistant) out_curvature = np.abs(np.degrees(np.mean(curvature))) / len(curvature) - # print(f'Curvature: {out_curvature:.2f}') - # print(f'Length of midline: ', f'{midline_length:.2f}') - # print(f'Thickness: {np.mean(levelpath_lengths):.2f}') - - # import matplotlib.pyplot as plt - # import matplotlib.tri as tri - # fig, ax = plt.subplots(figsize=(5, 4)) - # mtpltlb_tria = tri.Triangulation(tria.v[:,0], tria.v[:,1], triangles=tria.t) - # triang = plt.tricontourf(mtpltlb_tria, gf, cmap='autumn', alpha=0.2) - # ax.plot(midline_equidistant[:,0], midline_equidistant[:,1], 'r-', label=f'Levelsets')#, marker='o', markersize=2) - # #ax.plot(contour_2d[:,0], contour_2d[:,1], 'k-', label='Contour', alpha=0.6) - - # for i in range(len(levelpaths)): - # if levelpaths[i] is not None: - # ax.plot(levelpaths[i][:,0], levelpaths[i][:,1], 'r-', marker='o', markersize=0) # , - # label=f'Level {levelpath_lengths[i]:.2f}' - # ax.plot(midline_equidistant[:,0], midline_equidistant[:,1], '-', label='Midline', alpha=1, - # color='darkgoldenrod')#, marker='o', markersize=2) - - # #plt.colorbar(colorscale, label='Level values') - # # plot mesh - # ax.triplot(tria.v[:,0], tria.v[:,1], tria.t, 'k-', alpha=0.2, linewidth=0.5) - # #ax.scatter(path3d_resampled[99,0], path3d_resampled[99,1], c='g', s=20) - - # ax.set_aspect('equal') - # #plt.title('Levelpath on rotated Poisson') - # plt.legend() - # # invert x axis - # ax.invert_xaxis() - # plt.tight_layout() - # plt.axis('off') - # plt.savefig(f'levelsets.png', dpi=300, bbox_inches='tight') - # plt.show() return ( midline_length, diff --git a/CorpusCallosum/utils/utils.py b/CorpusCallosum/utils/utils.py deleted file mode 100644 index 15aa0714..00000000 --- a/CorpusCallosum/utils/utils.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2025 AI in Medical Imaging, German 
Center for Neurodegenerative Diseases(DZNE), Bonn - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. - # You may obtain a copy of the License at - # - # http://www.apache.org/licenses/LICENSE-2.0 - # - # Unless required by applicable law or agreed to in writing, software - # distributed under the License is distributed on an "AS IS" BASIS, - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - # See the License for the specific language governing permissions and - # limitations under the License. - - import os - import sys - - - class HiddenPrints: - """Context manager for suppressing stdout output. - - Temporarily redirects stdout to os.devnull to hide any print statements - within the context. - - Examples - -------- - >>> with HiddenPrints(): - ... print("This will not be visible") - >>> print("This will be visible") - """ - - def __enter__(self) -> None: - """Enter the context manager. - - Returns - ------- - None - """ - self._original_stdout = sys.stdout - sys.stdout = open(os.devnull, "w") - - def __exit__(self, exc_type: type | None, exc_val: Exception | None, - exc_tb: type | None) -> None: - """Exit the context manager. - - Parameters - ---------- - exc_type : type or None - Type of the exception that occurred, if any - exc_val : Exception or None - Exception instance that occurred, if any - exc_tb : type or None - Traceback of the exception that occurred, if any - - Returns - ------- - None - """ - sys.stdout.close() - sys.stdout = self._original_stdout \ No newline at end of file diff --git a/CorpusCallosum/visualization/visualization.py b/CorpusCallosum/visualization/visualization.py index 0d679373..539ec462 100644 --- a/CorpusCallosum/visualization/visualization.py +++ b/CorpusCallosum/visualization/visualization.py @@ -136,30 +136,28 @@ def plot_contours( vox_size: float | None = None, title: str = "", ) -> None: - """Plot contours and subdivisions of the corpus callosum. + """Create a figure of the contours (shape) and the subdivisions of the corpus callosum. Parameters ---------- transformed : np.ndarray Transformed image data split_contours : list[np.ndarray], optional - List of contour arrays for each subdivision, by default None + List of contour arrays for each subdivision (ignore contours on None). midline_equidistant : np.ndarray, optional - Midline points at equidistant spacing, by default None + Midline points at equidistant spacing (ignore midline on None). levelpaths : list[np.ndarray], optional - List of level paths for visualization, by default None + List of level paths for visualization (ignore level paths on None). output_path : str or Path, optional - Path to save the plot, by default None + Path to save the plot (do not save on None). ac_coords : np.ndarray, optional - AC coordinates for visualization, by default None + AC coordinates for visualization (ignore AC on None). pc_coords : np.ndarray, optional - PC coordinates for visualization, by default None + PC coordinates for visualization (ignore PC on None). vox_size : float, optional - Voxel size for scaling, by default None - title : str, optional - Title for the plot, by default "" - debug : bool, optional - Whether to show debug information, by default False + Voxel size for scaling. + title : str, default="" + Title for the plot. 
Notes ----- @@ -175,49 +173,46 @@ def plot_contours( if levelpaths: levelpaths = np.stack(levelpaths, axis=0) / vox_size - NO_PLOTS = 1 + int(split_contours is not None) + has_first_plot = bool(split_contours) or ac_coords is not None or pc_coords is not None + num_plots = 1 + int(has_first_plot) - _, ax = plt.subplots(1, NO_PLOTS, sharex=True, sharey=True, figsize=(15, 10)) + _, ax = plt.subplots(1, num_plots, sharex=True, sharey=True, figsize=(15, 10)) # NOTE: For all plots imshow shows y inverted current_plot = 0 - - if split_contours is not None: + if has_first_plot: ax[current_plot].imshow(transformed[transformed.shape[0] // 2], cmap="gray") - # ax[0].imshow(cc_mask, cmap='autumn') ax[current_plot].set_title(title) + if split_contours: for i, this_contour in enumerate(split_contours): ax[current_plot].fill(this_contour[0, :], -this_contour[1, :], color="steelblue", alpha=0.25) kwargs = {"color": "mediumblue", "linewidth": 0.7, "linestyle": "solid" if i != 0 else "dotted"} ax[current_plot].plot(this_contour[0, :], -this_contour[1, :], **kwargs) + if ac_coords is not None: ax[current_plot].scatter(ac_coords[1], ac_coords[0], color="red", marker="x") + if pc_coords is not None: ax[current_plot].scatter(pc_coords[1], pc_coords[0], color="blue", marker="x") - current_plot += 1 + current_plot += int(has_first_plot) reference_contour = split_contours[0] - ax[current_plot].imshow(transformed[transformed.shape[0] // 2], cmap="gray") - # ax[2].imshow(cc_mask, cmap='autumn') for this_path in levelpaths: ax[current_plot].plot(this_path[:, 0], -this_path[:, 1], color="brown", linewidth=0.8) ax[current_plot].set_title("Midline & Levelpaths") ax[current_plot].plot(midline_equidistant[:, 0], -midline_equidistant[:, 1], color="red") ax[current_plot].plot(reference_contour[0, :], -reference_contour[1, :], color="red", linewidth=0.5) + padding = 30 for a in ax.flatten(): a.set_aspect("equal", adjustable="box") a.axis("off") - - # get bounding box of contours - padding = 30 - ax[0].set_xlim(reference_contour[0, :].min() - padding, reference_contour[0, :].max() + padding) - ax[0].set_ylim((-reference_contour[1, :]).max() + padding, (-reference_contour[1, :]).min() - padding) + # get bounding box of contours + a.set_xlim(reference_contour[0, :].min() - padding, reference_contour[0, :].max() + padding) + a.set_ylim((-reference_contour[1, :]).max() + padding, (-reference_contour[1, :]).min() - padding) Path(output_path).parent.mkdir(parents=True, exist_ok=True) plt.savefig(output_path, dpi=300, bbox_inches="tight") - # plt.show() def plot_midplane(grid_orig: np.ndarray, orig: np.ndarray) -> None: diff --git a/recon_surf/recon-surf.sh b/recon_surf/recon-surf.sh index f0199664..5727c850 100755 --- a/recon_surf/recon-surf.sh +++ b/recon_surf/recon-surf.sh @@ -619,41 +619,52 @@ fi # ============================= CC SEGMENTATION ============================================ - - -{ - echo " " - echo "============ Creating and adding CC Segmentation ============" - echo " " -} | tee -a "$LF" -# create aseg.auto including corpus callosum segmentation and 46 sec, requires norm.mgz -# Note: if original input segmentation already contains CC, this will exit with ERROR -# in the future maybe check and skip this step (and next) -cmd="$python ${binpath}../CorpusCallosum/fastsurfer_cc.py --subject_dir $SUBJECTS_DIR/$subject --verbose" -RunIt "$cmd" "$LF" -# add CC into aparc.DKTatlas+aseg.deep (not sure if this is really needed) -cmd="$python ${FASTSURFER_HOME}/CorpusCallosum/paint_cc_into_pred.py -in_cc $mdir/callosum_seg_aseg_space.mgz -in_pred 
$asegdkt_segfile -out $mdir/aparc.DKTatlas+aseg.deep.withCC.mgz"
-RunIt "$cmd" "$LF"
-# add CC into aseg.auto.mgz as mri_cc did before. Not sure where this is used.
-cmd="$python ${FASTSURFER_HOME}/CorpusCallosum/paint_cc_into_pred.py -in_cc $mdir/callosum_seg_aseg_space.mgz -in_pred $mdir/$aseg_nocc -out $mdir/aseg.auto.mgz"
-RunIt "$cmd" "$LF"
-
-
-# {
-#   echo " "
-#   echo "============ Creating and adding CC Segmentation (mri_cc) ============"
-#   echo " "
-# } | tee -a "$LF"
-# # create aseg.auto including corpus callosum segmentation and 46 sec, requires norm.mgz
-# # Note: if original input segmentation already contains CC, this will exit with ERROR
-# # in the future maybe check and skip this step (and next)
-# cmd="mri_cc -aseg $aseg_nocc -o aseg.auto.mgz -lta $mdir/transforms/cc_up.lta $subject"
-# RunIt "$cmd" "$LF"
-# # add CC into aparc.DKTatlas+aseg.deep (not sure if this is really needed)
-# cmd="$python ${binpath}../CorpusCallosum/paint_cc_into_pred.py -in_cc $mdir/aseg.auto.mgz -in_pred $asegdkt_segfile -out $mdir/aparc.DKTatlas+aseg.deep.withCC.mgz"
-# RunIt "$cmd" "$LF"
-
-
+# here, we only generate the "necessary" files for the recon-surf pipeline to
+# complete; people should use the seg pipeline to get extended results.
+callosum_seg="callosum_seg_aseg_space.mgz"
+callosum_seg_manedit="$(add_file_suffix "$callosum_seg" "manedit")"
+aseg_auto="aseg.auto.mgz"
+CorpusCallosumDir="$FASTSURFER_HOME/CorpusCallosum"
+updated_cc_seg=0
+if [[ ! -e "$mdir/$aseg_auto" ]] || [[ ! -e "$mdir/$callosum_seg" ]] || [[ "$edits" == 1 ]]
+then
+  {
+    echo " "
+    echo "============ Creating and adding CC Segmentation ============"
+    echo " "
+  } | tee -a "$LF"
+fi
+# here, in edits mode we also check, if the corpus callosum should be updated based on an updated aseg.nocc
+if [[ ! -e "$mdir/$callosum_seg" ]] || \
+   { [[ "$edits" == 1 ]] && [[ "$(date -r "$mdir/$aseg_nocc" "+%s")" -gt "$(date -r "$mdir/$callosum_seg" "+%s")" ]] ; }
+then
+  {
+    echo "Segmenting the corpus callosum (mri/$callosum_seg is missing or older than mri/$aseg_nocc)."
+    echo " If you are interested in detailed and extended analysis and statistics of the Corpus Callosum,"
+    echo " use the corpus callosum pipeline of the segmentation pipeline (in run_fastsurfer.sh, i.e. run without --no_cc)."
+  }
+  updated_cc_seg=1
+  # create aseg.auto including corpus callosum segmentation (takes ~46 sec), requires norm.mgz
+  # Note: if original input segmentation already contains CC, this will exit with ERROR
+  # in the future maybe check and skip this step (and next)
+  cmda=($python "$CorpusCallosumDir/fastsurfer_cc.py" --sd "$SUBJECTS_DIR" --sid "$subject"
+        "--aseg_name" "$mdir/$aseg_nocc" "--segmentation_in_orig" "$mdir/$callosum_seg"
+        --threads "$threads"
+        # qc_snapshots are only defined by the seg_only pipeline
+        # limit the processing steps to run here
+        --slice_selection "middle" --cc_measures "none" --cc_mid_measures "none" --surf "none"
+        --thickness_overlay "none")
+  run_it "$LF" "${cmda[@]}"
+fi
+# do not move below statement up, fastsurfer_cc.py uses the $callosum_seg variable
+if [[ "$edits" == 1 ]] && [[ -e "$mdir/$callosum_seg_manedit" ]] ; then callosum_seg="$callosum_seg_manedit" ; fi
+cmd_paint_cc_into_pred=($python "$CorpusCallosumDir/paint_cc_into_pred.py" -in_cc "$mdir/$callosum_seg" -in_pred)
+if [[ ! -e "$mdir/$aseg_auto" ]] || [[ "$updated_cc_seg" == 1 ]]
+then
+  # add CC into aseg.auto.mgz as mri_cc did before. Not sure where this is used.
+  cmda=("${cmd_paint_cc_into_pred[@]}" "$mdir/$aseg_nocc" "-out" "$mdir/$aseg_auto")
+  run_it "$LF" "${cmda[@]}"
+fi
 
 # ============================= FILLED =====================================================
 {
diff --git a/run_fastsurfer.sh b/run_fastsurfer.sh
index 83505e91..367c2a31 100755
--- a/run_fastsurfer.sh
+++ b/run_fastsurfer.sh
@@ -32,6 +32,7 @@ fastsurfercnndir="$FASTSURFER_HOME/FastSurferCNN"
 cerebnetdir="$FASTSURFER_HOME/CerebNet"
 hypvinndir="$FASTSURFER_HOME/HypVINN"
 reconsurfdir="$FASTSURFER_HOME/recon_surf"
+CorpusCallosumDir="$FASTSURFER_HOME/CorpusCallosum"
 
 # Regular flags defaults
 subject=""
@@ -49,6 +50,7 @@ hypo_segfile=""
 hypo_statsfile=""
 hypvinn_flags=()
 hypvinn_regmode="coreg"
+cc_flags=()
 conformed_name=""
 conformed_name_t2=""
 norm_name=""
@@ -70,6 +72,7 @@ native_image="false"
 run_asegdkt_module="1"
 run_cereb_module="1"
 run_hypvinn_module="1"
+run_cc_module="1"
 threads_seg="1"
 threads_surf="1"
 # python3.10 -s excludes user-directory package inclusion
@@ -213,6 +216,11 @@ SEGMENTATION PIPELINE:
   --no_biasfield          Deactivate the calculation of partial volume-corrected
                             statistics.
 
+  CORPUS CALLOSUM MODULE:
+  --no_cc                 Skip the segmentation and analysis of the corpus callosum.
+  --qc_snap               Create QC snapshots in \$SUBJECTS_DIR/\$sid/qc_snapshots
+                            to simplify the QC process.
+
   HYPOTHALAMUS MODULE (HypVINN):
   --no_hypothal           Skip the hypothalamus segmentation.
   --no_biasfield          This option implies --no_hypothal, as the hypothalamus
@@ -458,6 +466,10 @@ case $key in
     --mask_name) mask_name="$1" ; warn_seg_only+=("$key" "$1") ; warn_base+=("$key" "$1") ; shift ;;
     --merged_segfile) merged_segfile="$1" ; shift ;;
 
+    # corpus callosum module options
+    #=============================================================
+    --no_cc) run_cc_module="0" ;;
+
     # cereb module options
     #=============================================================
     --no_cereb) run_cereb_module="0" ;;
@@ -480,7 +492,7 @@ case $key in
       ;;
 
     # several options that set a variable
-    --qc_snap) hypvinn_flags+=(--qc_snap) ;;
+    --qc_snap) hypvinn_flags+=(--qc_snap) ; cc_flags+=("--qc_output_dir" "qc_snapshots") ;;
 
     ##############################################################
     # surf-pipeline options
@@ -588,6 +600,8 @@ fi
 if [[ -z "$merged_segfile" ]] ; then merged_segfile="$subject_dir/mri/fastsurfer.merged.mgz" ; fi
 if [[ -z "$asegdkt_segfile" ]] ; then asegdkt_segfile="$subject_dir/mri/aparc.DKTatlas+aseg.deep.mgz" ; fi
 if [[ -z "$aseg_segfile" ]] ; then aseg_segfile="$subject_dir/mri/aseg.auto_noCCseg.mgz"; fi
+if [[ -z "$aseg_auto_segfile" ]] ; then aseg_auto_segfile="$subject_dir/mri/aseg.auto.mgz"; fi
+if [[ -z "$callosum_seg" ]] ; then callosum_seg="$subject_dir/mri/callosum.CC.orig.mgz"; fi
 if [[ -z "$asegdkt_statsfile" ]] ; then asegdkt_statsfile="$subject_dir/stats/aseg+DKT.stats" ; fi
 if [[ -z "$asegdkt_vinn_statsfile" ]] ; then asegdkt_vinn_statsfile="$subject_dir/stats/aseg+DKT.VINN.stats" ; fi
 if [[ -z "$aseg_vinn_statsfile" ]] ; then aseg_vinn_statsfile="$subject_dir/stats/aseg.VINN.stats" ; fi
@@ -708,6 +722,18 @@ then
   fi
 fi
 
+if [[ "$run_seg_pipeline" == "1" ]] && { [[ "$run_asegdkt_module" == "0" ]] && [[ "$run_cc_module" == "1" ]]; }
+then
+  if [[ ! -f "$asegdkt_segfile" ]]
+  then
+    echo "ERROR: To run the corpus callosum module without the asegdkt module, the asegdkt segmentation must already exist."
+    echo "       You passed --no_asegdkt but the asegdkt segmentation ($asegdkt_segfile) could not be found."
+    echo "       If the segmentation is not saved in the default location ($asegdkt_segfile_default),"
+    echo "       specify the absolute path and name via --asegdkt_segfile"
+    exit 1
+  fi
+fi
+
 if [[ "$run_surf_pipeline" == "1" ]] && [[ "$native_image" != "false" ]]
 then
   echo "ERROR: The surface pipeline is not compatible with the options --native_image or "
@@ -1078,6 +1104,88 @@ then
     fi
   fi
 
+  if [[ "$run_cc_module" == "1" ]]
+  then
+    # ============================= CC SEGMENTATION ============================================
+
+    # generate file names for the analysis
+    asegdkt_withcc_segfile="$(add_file_suffix "$asegdkt_segfile" "withCC")"
+    asegdkt_withcc_vinn_statsfile="$(add_file_suffix "$asegdkt_vinn_statsfile" "withCC")"
+    aseg_auto_statsfile="$(add_file_suffix "$aseg_auto_statsfile" "withCC")"
+    # note: callosum manedit currently only affects inpainting and not internal FastSurferCC processing (surfaces etc)
+    callosum_seg_manedit="$(add_file_suffix "$callosum_seg" "manedit")"
+    # generate callosum segmentation, mesh, shape and downstream measure files
+    cmd=($python "$CorpusCallosumDir/fastsurfer_cc.py" --sd "$sd" --sid "$subject" --threads "$threads_seg"
+         "--aseg_name" "$asegdkt_segfile" "--segmentation_in_orig" "$callosum_seg" "${cc_flags[@]}")
+    {
+      echo_quoted "${cmd[@]}"
+      "${cmd[@]}"
+      if [[ "${PIPESTATUS[0]}" != 0 ]] ; then echo "ERROR: FastSurferCC corpus callosum analysis failed!" ; exit 1 ; fi
+      if [[ "$edits" == 1 ]] && [[ -f "$callosum_seg_manedit" ]] ; then callosum_seg="$callosum_seg_manedit" ; fi
+
+      # add CC into aparc.DKTatlas+aseg.deep.mgz and aseg.auto.mgz as mri_cc did before.
+      cmd=($python "$CorpusCallosumDir/paint_cc_into_pred.py" -in_cc "$callosum_seg" -in_pred "$asegdkt_segfile"
+           "-out" "$asegdkt_withcc_segfile" "-aseg" "$aseg_auto_segfile")
+      echo_quoted "${cmd[@]}"
+      "${cmd[@]}"
+      if [[ "${PIPESTATUS[0]}" != 0 ]] ; then echo "ERROR: asegdkt cc inpainting failed!" ; exit 1 ; fi
+
+      if [[ "$run_biasfield" == 1 ]]
+      then
+        # TODO: decide how to measure the size of the white matter, maybe import measures from previous?
+        # TODO: decide whether to include the fornix PV-corrected volume
+        # PV list here and not asegdkt_segfile: 192 Fornix
+        cmd=($python "${fastsurfercnndir}/segstats.py" --segfile "$asegdkt_withcc_segfile" --normfile "$norm_name"
+             --lut "$fastsurfercnndir/config/FreeSurferColorLUT.txt" --sd "${sd}" --sid "${subject}"
+             --ids 2 4 5 7 8 10 11 12 13 14 15 16 17 18 24 26 28 31 41 43 44 46 47 49 50 51 52 53
+                   54 58 60 63 77 192 251 252 253 254 255
+                   1002 1003 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018
+                   1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1034 1035
+                   2002 2003 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018
+                   2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2034 2035
+             --threads "$threads_seg" --empty --excludeid 0
+             --segstatsfile "$asegdkt_withcc_vinn_statsfile"
+             measures
+             # the following measures are unaffected by CC and do not need to be recomputed
+             --import Mask --file "$asegdkt_vinn_statsfile"
+             # recompute the measures based on "better" volumes:
+             --compute BrainSeg BrainSegNotVent SupraTentorial SupraTentorialNotVent
+                       SubCortGray rhCerebralWhiteMatter lhCerebralWhiteMatter CerebralWhiteMatter
+        )
+        echo_quoted "${cmd[@]}"
+        "${cmd[@]}"
+        if [[ "${PIPESTATUS[0]}" != 0 ]] ; then
+          echo "ERROR: asegdkt statsfile ($asegdkt_withcc_segfile) generation failed!"
; exit 1 + # this will only terminate the subshell + fi + fi + } 2>&1 | tee -a "$seg_log" + code="${PIPESTATUS[0]}" + if [[ "$code" != 0 ]]; then exit 1; fi # forward subshell exit to main script + + if [[ "$run_biasfield" == 1 ]] + then + { + # TODO: decide how to measure the size of the white matter + # TODO: decide whether to include the fornix PV-corrected volume + # PV list here and not asegdkt_segfile: 192 Fornix + cmd=($python "${fastsurfercnndir}/segstats.py" --segfile "$aseg_auto_segfile" --normfile "$norm_name" + --lut "$fastsurfercnndir/config/FreeSurferColorLUT.txt" --sd "${sd}" --sid "${subject}" + --threads "$threads_seg" --empty --excludeid 0 + --ids 2 4 3 5 7 8 10 11 12 13 14 15 16 17 18 24 26 28 31 41 42 43 44 46 47 49 50 51 52 53 54 58 60 63 77 + 192 251 252 253 254 255 + --segstatsfile "$aseg_auto_statsfile" + measures --import "all" --file "$asegdkt_withcc_vinn_statsfile" + ) + echo_quoted "${cmd[@]}" + "${cmd[@]}" 2>&1 + if [[ "${PIPESTATUS[0]}" != 0 ]] ; then echo "ERROR: aseg statsfile ($aseg_auto_segfile) failed!" ; exit 1 ; fi + } | tee -a "$seg_log" + if [[ "${PIPESTATUS[0]}" != 0 ]] ; then exit 1; fi # forward subshell exit to main script + + fi + fi + if [[ "$run_cereb_module" == "1" ]] then if [[ "$run_biasfield" == "1" ]] From f4652ff2c23cee92b4da06b36b36b3b451428770 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Tue, 25 Nov 2025 18:54:23 +0100 Subject: [PATCH 30/68] Fix doc build errors and ruff optimization codes Fix typos in typing and docstrings --- CorpusCallosum/shape/cc_subsegment_contour.py | 6 ++---- CorpusCallosum/visualization/visualization.py | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/CorpusCallosum/shape/cc_subsegment_contour.py b/CorpusCallosum/shape/cc_subsegment_contour.py index 16ab7dc2..4055a026 100644 --- a/CorpusCallosum/shape/cc_subsegment_contour.py +++ b/CorpusCallosum/shape/cc_subsegment_contour.py @@ -146,8 +146,7 @@ def subsegment_midline_orthogonal(midline, area_weights, contour, plot=True, ax= edge_ortho_vectors = np.column_stack((-edge_directions[:, 1], edge_directions[:, 0])) edge_ortho_vectors = edge_ortho_vectors / np.linalg.norm(edge_ortho_vectors, axis=1)[:, None] - split_contours = [] - split_contours.append(contour) + split_contours = [contour] for pt_idx, split_point in enumerate(split_points): intersections = [] @@ -277,8 +276,7 @@ def subsegment_midline_orthogonal(midline, area_weights, contour, plot=True, ax= # convert area_weights into fraction of total line length # e.g. area_weights=[1/6, 1/2, 2/3, 3/4] to ['1/6', '2/3', ...] # cumulative difference - area_weights_diff = [] - area_weights_diff.append(area_weights[0]) + area_weights_diff = [area_weights[0]] for i in range(1, len(area_weights)): area_weights_diff.append(area_weights[i] - area_weights[i - 1]) area_weights_diff.append(1 - area_weights[-1]) diff --git a/CorpusCallosum/visualization/visualization.py b/CorpusCallosum/visualization/visualization.py index 539ec462..398fde3d 100644 --- a/CorpusCallosum/visualization/visualization.py +++ b/CorpusCallosum/visualization/visualization.py @@ -136,14 +136,14 @@ def plot_contours( vox_size: float | None = None, title: str = "", ) -> None: - """Creates a figure of the countours (shape) and the subdivisions of the corpus callosum. + """Creates a figure of the contours (shape) and the subdivisions of the corpus callosum. 
     Parameters
     ----------
     transformed : np.ndarray
         Transformed image data
     split_contours : list[np.ndarray], optional
-        List of contour arrays for each subdivision (ignore countours on None)
+        List of contour arrays for each subdivision (ignore contours on None)
     midline_equidistant : np.ndarray, optional
         Midline points at equidistant spacing (ignore midline on None).
     levelpaths : list[np.ndarray], optional
         List of level paths for visualization

From ebee7add74bd533f9276a5ddbd76f21c8b133c5b Mon Sep 17 00:00:00 2001
From: ClePol
Date: Tue, 25 Nov 2025 16:24:15 +0100
Subject: [PATCH 31/68] updated helptexts & formatting

---
 CorpusCallosum/cc_visualization.py | 59 +++++++++++++++++++++++++-----
 1 file changed, 49 insertions(+), 10 deletions(-)

diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py
index 8d9aa856..ed4a7c5d 100644
--- a/CorpusCallosum/cc_visualization.py
+++ b/CorpusCallosum/cc_visualization.py
@@ -14,25 +14,53 @@ def make_parser() -> argparse.ArgumentParser:
     """Create a command line parser for the visualization pipeline."""
     parser = argparse.ArgumentParser(description="Visualize corpus callosum from template files.")
-    parser.add_argument("--contours", type=str, required=False, help="Path to contours.txt file", default=None)
-    parser.add_argument("--thickness", type=str, required=True, help="Path to thickness_values.txt file")
+    parser.add_argument(
+        "--contours",
+        type=str,
+        required=False,
+        help="Path to contours.txt file; if not provided, uses fsaverage template.",
+        metavar="CONTOURS_PATH",
+        default=None
+    )
+    parser.add_argument(
+        "--thickness",
+        type=str,
+        required=True,
+        help="Path to thickness_values.txt file.",
+        metavar="THICKNESS_VALUES_PATH"
+    )
     parser.add_argument(
         "--measurement_points",
         type=str,
         required=True,
-        help="Path to measurement points file containing the original vertex indices where thickness was measured",
+        help="Path to measurement points file containing the original vertex indices where thickness was measured.",
     )
-    parser.add_argument("--output_dir", type=str, required=True, help="Directory for output files")
-    parser.add_argument("--resolution", type=float, default=1.0, help="Resolution in mm for the mesh")
+    parser.add_argument("--output_dir",
+        type=str,
+        required=True,
+        help="Directory for output files.",
+        metavar="OUTPUT_DIR"
+    )
     parser.add_argument(
-        "--smoothing_window", type=int, default=5, help="Window size for smoothing the contour"
+        "--resolution",
+        type=float,
+        default=1.0,
+        help="Resolution in mm for the mesh.",
+        metavar="RESOLUTION"
+    )
+    parser.add_argument(
+        "--smoothing_window",
+        type=int,
+        default=5,
+        help="Window size for smoothing the contour.",
+        metavar="SMOOTHING_WINDOW"
     )
     parser.add_argument(
         "--colormap",
         type=str,
         default="red_to_yellow",
         choices=["red_to_blue", "blue_to_red", "red_to_yellow", "yellow_to_red"],
-        help="Colormap to use for thickness visualization",
+        help="Colormap to use for thickness visualization.",
     )
     parser.add_argument(
         "--color_range",
         type=float,
         nargs=2,
         default=None,
         metavar=("MIN", "MAX"),
         required=False,
-        help="Specify the range for the colorbar (2 values: min max). Defaults to automatic choice.",
+        help="Specify the range for the colorbar (2 values: min max). Defaults to automatic choice. \ (e.g.
--color_range 0 10).", + ) + parser.add_argument( + "--legend", + type=str, + default="Thickness (mm)", + help="Legend for the colorbar.", + metavar="LEGEND") + parser.add_argument( + "--twoD", + action="store_true", + help="Generate 2D visualization instead of 3D mesh.", + metavar="TWO_D" ) - parser.add_argument("--legend", type=str, default="Thickness (mm)", help="Legend for the colorbar") - parser.add_argument("--twoD", action="store_true", help="Generate 2D visualization instead of 3D mesh") return parser From 189cfc8594f69c0d8bb99dad7d5b44180bcb5df0 Mon Sep 17 00:00:00 2001 From: ClePol Date: Wed, 26 Nov 2025 17:35:25 +0100 Subject: [PATCH 32/68] documentation and review comments --- CorpusCallosum/README.md | 375 +----------------- CorpusCallosum/cc_visualization.py | 11 +- CorpusCallosum/data/constants.py | 5 +- CorpusCallosum/data/fsaverage_cc_template.py | 27 +- .../data/generate_fsaverage_centroids.py | 46 +-- CorpusCallosum/data/read_write.py | 2 +- CorpusCallosum/fastsurfer_cc.py | 45 ++- .../localization/localization_inference.py | 25 +- CorpusCallosum/paint_cc_into_pred.py | 11 +- .../registration/mapping_helpers.py | 19 +- .../segmentation/segmentation_inference.py | 96 ++++- .../segmentation_postprocessing.py | 56 +-- CorpusCallosum/shape/cc_mesh.py | 8 +- CorpusCallosum/shape/cc_metrics.py | 56 +-- CorpusCallosum/shape/cc_postprocessing.py | 20 +- CorpusCallosum/shape/cc_subsegment_contour.py | 85 +--- CorpusCallosum/shape/cc_thickness.py | 22 +- .../transforms/localization_transforms.py | 16 +- .../transforms/segmentation_transforms.py | 22 +- CorpusCallosum/visualization/visualization.py | 40 +- README.md | 12 +- doc/api/CorpusCallosum_utils.rst | 1 - doc/api/index.rst | 5 + doc/overview/FLAGS.md | 4 +- doc/overview/OUTPUT_FILES.md | 27 +- doc/overview/index.rst | 1 + doc/scripts/advanced.rst | 2 + 27 files changed, 335 insertions(+), 704 deletions(-) diff --git a/CorpusCallosum/README.md b/CorpusCallosum/README.md index 35605674..a417dd84 100644 --- a/CorpusCallosum/README.md +++ b/CorpusCallosum/README.md @@ -3,380 +3,15 @@ A deep learning-based pipeline for automated segmentation, analysis, and shape analysis of the corpus callosum in brain MRI scans. Also segments the fornix, localizes the anterior and posterior commissure (AC and PC) and standardizes the orientation of the brain. -## Overview - -This pipeline combines localization and segmentation deep learning models to: -1. Detect AC (Anterior Commissure) and PC (Posterior Commissure) points -2. Extract and align midplane slices -3. Segment the corpus callosum -4. Perform advanced morphometry for corpus callosum, including subdivision, thickness analysis, and various shape metrics -5. Generate visualizations and measurements - +For detailed documentation, please refer to: +- [Module Overview](../doc/overview/modules/CC.md): Detailed description of the pipeline, workflow, and analysis options. +- [Output Files](../doc/overview/OUTPUT_FILES.md#corpus-callosum-module): List of output files and their descriptions. ## Quickstart ```bash python3 fastsurfer_cc.py --subject_dir /path/to/fastsurfer/output --verbose -` `` - -Gives all standard outputs. Then corpus callosum morphometry can be found at `stats/callosum.CC.midslice.json`, including 100 thickness measurements and areas of sub-segments. -Visualization will be placed in `/path/to/fastsurfer/output/qc_snapshots`. For more detailed info see the following sections. 
- -## Command Line Interfaces - -### Main Pipeline: `fastsurfer_cc.py` - -The main pipeline script performs the complete corpus callosum analysis workflow. - -#### Basic Usage - -```bash -# Using individual file paths -python3 fastsurfer_cc.py --in_mri /path/to/input/mri.mgz --aseg /path/to/input/aseg.mgz --output_dir /path/to/output --verbose - -# Using FastSurfer/FreeSurfer subject directory structure -python3 fastsurfer_cc.py --subject_dir /path/to/fastsurfer/output --verbose -``` - -#### Required Arguments - -Choose one of these input methods: - -**Option 1: Individual files** -- `--in_mri PATH`: Input MRI file path (FreeSurfer-conformed) -- `--aseg PATH`: Input segmentation file path -- `--output_dir PATH`: Directory for output files - -**Option 2: FastSurfer/FreeSurfer subject directory** -- `--subject_dir PATH`: Subject directory containing standard FastSurfer structure - - Automatically uses `mri/orig.mgz` and `mri/aparc.DKTatlas+aseg.deep.mgz` - - Creates standard output paths in FastSurfer structure - -#### Optional Arguments - -**General Options:** -- `--verbose`: Enable verbose output and debug plots -- `--debug_output_dir PATH`: Directory for debug outputs -- `--cpu`: Force CPU usage even when CUDA is available - -**Shape Analysis Parameters:** -- `--num_thickness_points INT`: Number of points for thickness estimation (default: 100) -- `--subdivisions FLOAT [FLOAT ...]`: List of subdivision fractions for CC subsegmentation (default: following Hofer-Frahm definition) -- `--subdivision_method {shape,vertical,angular,eigenvector}`: Method for contour subdivision (default: "shape") - - `shape`: Intercallosal subdivision perpendicular to intercallosal line - - `vertical`: Orthogonal to the most anterior and posterior points in AC/PC standardized CC contour - - `angular`: Subdivision based on equally spaced angles (Hampel et al.) 
- - `eigenvector`: Primary direction (same as FreeSurfer's mri_cc) -- `--contour_smoothing FLOAT`: Gaussian sigma for smoothing during contour detection (default: 1.0) -- `--slice_selection {middle,all,INT}`: Which slices to process (default: "all") - -**Custom Output Paths:** -- `--upright_volume_path PATH`: Path for upright volume output -- `--segmentation_path PATH`: Path for segmentation output -- `--postproc_results_path PATH`: Path for postprocessing results -- `--cc_markers_path PATH`: Path for CC markers output -- `--upright_lta_path PATH`: Path for upright LTA transform -- `--orient_volume_lta_path PATH`: Path for orientation volume LTA transform -- `--orig_space_segmentation_path PATH`: Path for segmentation in original space -- `--qc_image_path PATH`: Path for QC visualization image - -**Template Saving:** -- `--save_template PATH`: Directory path to save contours.txt and thickness_values.txt files - -#### Examples - -```bash -# Basic analysis with FreeSurfer subject directory -python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 --verbose - -# Custom shape analysis parameters -python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \ - --num_thickness_points 150 \ - --subdivisions 0.2 0.4 0.6 0.8 \ - --subdivision_method angular \ - --contour_smoothing 1.5 - -# Process all slices instead of just middle slice -python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \ - --slice_selection all - -# Save template files for visualization -python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \ - --save_template /data/templates/sub001 -``` - -## Outputs - -The pipeline produces the following outputs in the specified output directory: - -### Main Pipeline Outputs - -**Analysis Results:** -- `stats/callosum.CC.midslice.json`: Contains detected landmarks and measurements for the middle slice -- `stats/callosum.CC.all_slices.json`: Enhanced postprocessing results with per-slice analysis - -**Transformation Matrices:** -- `mri/transforms/cc_up.lta`: Transformation from original to upright space (aligned to fsaverage, CC midslice at the center) -- `mri/transforms/orient_volume.lta`: Transformation a CC, AC & PC standardized space. The CC is at the center and AC & PC on the coordinate line, standardizing the head orientation. - -**Image Volumes:** -- `mri/callosum_seg_upright.mgz`: Corpus callosum segmentation in upright space (aligned to fsaverage, matching cc_up.lta) -- `mri/callosum_seg_aseg_space.mgz`: Corpus callosum segmentation in conformed image orientation (aligned to orig.mgz and other segmentations) -- `mri/callosum_seg_soft.mgz`: Corpus callosum soft labels (segmentation probabilities, upright space) -- `mri/fornix_seg_soft.mgz`: Fornix soft labels (segmentation probabilities, upright space) -- `mri/background_seg_soft.mgz`: Background soft labels (segmentation probabilities, upright space) - - -**Quality Control and Visualizations:** -- `qc_snapshots/callosum.png`: Debug visualization of corpus callosum contours and thickness measurements -- `qc_snapshots/callosum_thickness.png`: 3D thickness visualization (when using `--slice_selection all`) -- `qc_snapshots/corpus_callosum.html`: Interactive 3D mesh visualization (when using `--slice_selection all`) - - -**Surface Files (only provided when using `--slice_selection all`):** -- `surf/callosum.surf`: FreeSurfer surface format for integration with FreeSurfer tools (e.g. 
freeview) -- `surf/callosum.thickness.w`: FreeSurfer overlay file containing thickness values -- `surf/callosum_mesh.vtk`: VTK format mesh file for 3D visualization - -**Template Files (when --save_template is used):** - -- `contours.txt`: Corpus callosum contour coordinates for visualization -- `thickness_values.txt`: Thickness measurements at each contour point -- `measurement_points.txt`: Original vertex indices where thickness was measured - -## JSON Output Structure - -The pipeline generates two main JSON files with detailed measurements and analysis results: - -### `stats/callosum.CC.midslice.json` (Middle Slice Analysis) - -This file contains measurements from the middle sagittal slice and includes: - -**Shape Measurements (single values):** -- `total_area`: Total corpus callosum area (mm²) -- `total_perimeter`: Total perimeter length (mm) -- `circularity`: Shape circularity measure (4π × area / perimeter²) -- `cc_index`: Corpus callosum shape index (length/width ratio) -- `midline_length`: Length along the corpus callosum midline (mm) -- `curvature`: Average curve of the midline (degrees), measured by angle between it's sub-segements - -**Subdivisions** -- `areas`: Areas of CC using an improved Hofer-Frahm sub-division method (mm²). This gives more consistent sub-segemnts while preserving the original ratios. - -**Thickness Analysis:** -- `thickness`: Average corpus callosum thickness (mm) -- `thickness_profile`: Thickness profile (mm) of the corpus callosum slice (100 thickness values by default, listed from anterior to posterior CC ends) - - -**Volume Measurements (when multiple slices processed):** -- `cc_5mm_volume`: Total CC volume within 5mm slab using voxel counting (mm³) -- `cc_5mm_volume_pv_corrected`: Volume with partial volume correction using CC contours (mm³) - -**Anatomical Landmarks:** -- `ac_center`: Anterior commissure coordinates in original image space -- `pc_center`: Posterior commissure coordinates in original image space -- `ac_center_oriented_volume`: AC coordinates in standardized space (orient_volume.lta) -- `pc_center_oriented_volume`: PC coordinates in standardized space (orient_volume.lta) -- `ac_center_upright`: AC coordinates in upright space (cc_up.lta) -- `pc_center_upright`: PC coordinates in upright space (cc_up.lta) - -### `stats/callosum.CC.all_slices.json` (Multi-Slice Analysis) - -This file contains comprehensive per-slice analysis when using `--slice_selection all`: - -**Global Parameters:** -- `slices_in_segmentation`: Total number of slices in the segmentation volume -- `voxel_size`: Voxel dimensions [x, y, z] in mm -- `subdivision_method`: Method used for anatomical subdivision -- `num_thickness_points`: Number of points used for thickness estimation -- `subdivision_ratios`: Subdivision fractions used for regional analysis -- `contour_smoothing`: Gaussian sigma used for contour smoothing -- `slice_selection`: Slice selection mode used - -**Per-Slice Data (`slices` array):** - -Each slice entry contains the shape measurements, thickness analysis and sub-divisions as described above. - - - - -## Visualization: `cc_visualization.py` - -Creates advanced visualizations of corpus callosum from template files generated by the main pipeline. -Useful for visualization of analysis results. 
- -#### Basic Usage - -```bash -# Using contours file -python3 cc_visualization.py --contours /path/to/contours.txt \ - --thickness /path/to/thickness_values.txt \ - --measurement_points /path/to/measurement_points.txt \ - --output_dir /path/to/output - -# Using fsaverage template (no contours file) -python3 cc_visualization.py \ - --thickness /path/to/thickness_values.txt \ - --measurement_points /path/to/measurement_points.txt \ - --output_dir /path/to/output -``` - -#### Required Arguments - -- `--thickness PATH`: Path to thickness_values.txt file -- `--measurement_points PATH`: Path to measurement points file containing original vertex indices -- `--output_dir PATH`: Directory for output files - -#### Optional Arguments - -**Input:** -- `--contours PATH`: Path to contours.txt file (if not provided, uses fsaverage template) - -**Mesh Parameters:** -- `--resolution FLOAT`: Resolution in mm for the mesh (default: 1.0) -- `--smooth_iterations INT`: Number of smoothing iterations to apply to the mesh (default: 1) - -**Visualization Options:** -- `--colormap {red_to_blue,blue_to_red,red_to_yellow,yellow_to_red}`: Colormap for thickness visualization (default: "red_to_yellow") -- `--color_range MIN MAX`: Optional fixed range for the colorbar -- `--legend STRING`: Legend for the colorbar (default: "Thickness (mm)") -- `--twoD`: Generate 2D visualization instead of 3D mesh - -#### Colormap Options - -- `red_to_blue`: Red → Orange → Grey → Light Blue → Blue -- `blue_to_red`: Blue → Light Blue → Grey → Orange → Red -- `red_to_yellow`: Red → Yellow → Light Blue → Blue -- `yellow_to_red`: Yellow → Light Blue → Blue → Red - -#### Examples - -```bash -# Basic 3D mesh visualization -python3 cc_visualization.py \ - --thickness /data/templates/sub001/thickness_values.txt \ - --measurement_points /data/templates/sub001/measurement_points.txt \ - --output_dir /data/visualizations/sub001 - -# 2D visualization with custom colormap -python3 cc_visualization.py \ - --thickness /data/templates/sub001/thickness_values.txt \ - --measurement_points /data/templates/sub001/measurement_points.txt \ - --output_dir /data/visualizations/sub001 \ - --twoD \ - --colormap blue_to_red -``` - -## Analysis and Visualization Workflow - -The pipeline supports different analysis modes that determine the type of template data generated and corresponding visualization options: - -### 3D Analysis and Visualization - -When running the main pipeline with `--slice_selection all` and `--save_template`, a complete 3D template is generated: - -```bash -# Generate 3D template data -python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \ - --slice_selection all \ - --save_template /data/templates/sub001 -``` - -This creates: -- `contours.txt`: Multi-slice contour data for 3D reconstruction -- `thickness_values.txt`: Thickness measurements across all slices -- `measurement_points.txt`: 3D vertex indices for thickness measurements - -The 3D template can then be visualized using the standard 3D mesh options: - -```bash -# Create 3D mesh visualization -python3 cc_visualization.py \ - --contours /data/templates/sub001/contours.txt \ - --thickness /data/templates/sub001/thickness_values.txt \ - --measurement_points /data/templates/sub001/measurement_points.txt \ - --output_dir /data/visualizations/sub001 ``` -**3D Analysis Benefits:** -- Generates complete surface meshes (VTK, FreeSurfer formats) -- Enables volumetric thickness analysis -- Supports advanced 3D visualizations with proper surface topology -- Creates FreeSurfer-compatible 
overlay files for integration with other tools - -### 2D Analysis and Visualization - -When using `--slice_selection middle` or a specific slice number with `--save_template`: - -```bash -# Generate 2D template data (middle slice) -python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \ - --slice_selection middle \ - --save_template /data/templates/sub001 - -# Or specific slice -python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \ - --slice_selection 5 \ - --save_template /data/templates/sub001 -``` - -This creates template data for a single slice, which should be visualized in 2D mode: - -```bash -# Create 2D visualization -python3 cc_visualization.py \ - --thickness /data/templates/sub001/thickness_values.txt \ - --measurement_points /data/templates/sub001/measurement_points.txt \ - --output_dir /data/visualizations/sub001 \ - --twoD -``` - -**2D Analysis Benefits:** -- Faster processing for single-slice analysis -- 2D visualization is most suitable for displaying downstream statistics - -### Surface Generation Requirements - -**Important:** Complete surface files (VTK, FreeSurfer surface formats, overlay files) are only generated when using `--slice_selection all`. Single-slice analysis cannot produce proper 3D surface topology and will not generate these files. - -**3D Surface Outputs (only with `--slice_selection all`):** -- `cc_mesh.vtk`: Complete 3D surface mesh -- `cc_mesh.fssurf`: FreeSurfer surface format -- `cc_mesh_overlay.curv`: Thickness overlay for FreeSurfer visualization - -**2D Outputs (any slice selection):** -- `cc_mesh_snap.png`: 2D visualization or 3D mesh snapshot -- Standard analysis JSON files with measurements - -### Choosing Analysis Mode - -**Use 3D Analysis (`--slice_selection all`) when:** -- You need complete volumetric analysis -- Surface-based visualization is required -- Integration with FreeSurfer workflows is needed -- Comprehensive thickness mapping across the entire corpus callosum is desired - -**Use 2D Analysis (`--slice_selection middle` or specific slice) when:** -- Traditional single-slice morphometry is sufficient -- Faster processing is preferred -- Focus is on mid-sagittal cross-sectional measurements -- Compatibility with classical corpus callosum studies is needed - -**Note:** The default behavior is `--slice_selection all` for comprehensive 3D analysis. Use `--slice_selection middle` to process only the middle slice for faster, traditional 2D analysis. - - - -## Visualization Tool Outputs - -When using `cc_visualization.py`, additional outputs are generated (for advanced users). - -**3D Mode Outputs (default):** -- `cc_mesh.vtk`: VTK format mesh file for 3D visualization -- `cc_mesh.fssurf`: FreeSurfer surface format -- `cc_mesh_overlay.curv`: FreeSurfer overlay file with thickness values -- `cc_mesh.html`: Interactive 3D mesh visualization -- `cc_mesh_snap.png`: Snapshot image of the 3D mesh -- `midslice_2d.png`: 2D visualization of the middle slice - -**2D Mode Outputs (when `--twoD` is specified):** -- `cc_thickness_2d.png`: 2D contour visualization with thickness colormap \ No newline at end of file +Gives all standard outputs. Then corpus callosum morphometry can be found at `stats/callosum.CC.midslice.json`, including 100 thickness measurements and areas of sub-segments. +Visualization will be placed in `/path/to/fastsurfer/output/qc_snapshots`. 
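To make the JSON output mentioned in the quickstart concrete, here is a minimal sketch of reading the midslice morphometry in Python. The field names (`total_area`, `thickness`, `thickness_profile`, `areas`) follow the JSON structure described in the README text this patch relocates to `doc/overview`; the subject path is a placeholder, not a fixed convention.

```python
import json

# Placeholder path; point this at your own FastSurfer subject directory.
stats_path = "/path/to/fastsurfer/output/stats/callosum.CC.midslice.json"

with open(stats_path) as f:
    stats = json.load(f)

# Field names as documented for the midslice JSON (they may evolve with the module).
print(f"total area:     {stats['total_area']:.1f} mm^2")
print(f"mean thickness: {stats['thickness']:.2f} mm")
print(f"profile points: {len(stats['thickness_profile'])}")  # 100 by default
print(f"subsegment areas (mm^2): {stats['areas']}")
```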
diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index ed4a7c5d..d23bf07a 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -38,7 +38,13 @@ def make_parser() -> argparse.ArgumentParser: parser.add_argument("--output_dir", type=str, required=True, - help="Directory for output files.", + help="Directory for output files. Writes: \\\ + cc_mesh.html - Interactive 3D mesh visualization (HTML file) \\\ + midslice_2d.png - 2D midslice visualization of the corpus callosum \\\ + cc_mesh.vtk - VTK mesh file format \\\ + cc_mesh.fssurf - FreeSurfer surface file \\\ + cc_mesh_overlay.curv - FreeSurfer curvature overlay file \\\ + cc_mesh_snap.png - Screenshot/snapshot of the 3D mesh (requires whippersnappy>=1.3.1)", metavar="OUTPUT_DIR" ) parser.add_argument( @@ -60,7 +66,7 @@ def make_parser() -> argparse.ArgumentParser: type=str, default="red_to_yellow", choices=["red_to_blue", "blue_to_red", "red_to_yellow", "yellow_to_red"], - help="Colormap to use for thickness visualization.", + help="Colormap to use for thickness visualization, lower to higher values.", ) parser.add_argument( "--color_range", @@ -82,7 +88,6 @@ def make_parser() -> argparse.ArgumentParser: "--twoD", action="store_true", help="Generate 2D visualization instead of 3D mesh.", - metavar="TWO_D" ) return parser diff --git a/CorpusCallosum/data/constants.py b/CorpusCallosum/data/constants.py index 8087079b..77e4868d 100644 --- a/CorpusCallosum/data/constants.py +++ b/CorpusCallosum/data/constants.py @@ -22,15 +22,16 @@ FSAVERAGE_MIDDLE = 128 # Middle slice index in fsaverage space CC_LABEL = 192 # Label value for corpus callosum in segmentation FORNIX_LABEL = 250 # Label value for fornix in segmentation +THIRD_VENTRICLE_LABEL = 4 # Label value for third ventricle in segmentation SUBSEGMENT_LABELS = [251, 252, 253, 254, 255] # labels for subsegments in segmentation -STANDARD_INPUT_PATHS = { +DEFAULT_INPUT_PATHS = { "conf_name": "mri/orig.mgz", "aseg_name": "mri/aparc.DKTatlas+aseg.deep.mgz", } -STANDARD_OUTPUT_PATHS = { +DEFAULT_OUTPUT_PATHS = { ## images "upright_volume": None, # orig.mgz mapped to upright space ## segmentations diff --git a/CorpusCallosum/data/fsaverage_cc_template.py b/CorpusCallosum/data/fsaverage_cc_template.py index e0106421..b331640b 100644 --- a/CorpusCallosum/data/fsaverage_cc_template.py +++ b/CorpusCallosum/data/fsaverage_cc_template.py @@ -134,6 +134,7 @@ def load_fsaverage_cc_template() -> tuple[ outside_contour = contour_with_thickness[0].T + # make sure the CC stays in shape despite smoothing by moving endpoints outwards outside_contour[0][anterior_endpoint_idx] -= 55 outside_contour[0][posterior_endpoint_idx] += 30 @@ -144,30 +145,4 @@ def load_fsaverage_cc_template() -> tuple[ outside_contour = outside_contour_smoothed - # Plot CC contour with levelsets - - # midline_equidistant = output_dict['midline_equidistant'] - # levelpaths = output_dict['levelpaths'] - # plt.figure(figsize=(12, 8)) - - # plt.plot(outside_contour[0], outside_contour[1], 'k-', linewidth=2) - - # # Plot the midline - # if midline_equidistant is not None: - # midline_x, midline_y = zip(*midline_equidistant) - # plt.plot(midline_x, midline_y, 'r-', linewidth=2, label='Midline') - - # # Plot the level paths - # if levelpaths: - # for i, path in enumerate(levelpaths): - # path_x, path_y = path[:,0], path[:,1] - # plt.plot(path_x, path_y, 'g--', linewidth=1, alpha=0.7, label=f'Level path {i+1}' if i == 0 else "") - # plt.plot(path_x, path_y, 'gx', 
markersize=4, alpha=0.7) - - # plt.axis('equal') - # plt.title('Corpus Callosum Contour with Levelsets') - # plt.legend(loc='best') - # plt.grid(True, linestyle='--', alpha=0.7) - # plt.show() - return outside_contour, anterior_endpoint_idx, posterior_endpoint_idx diff --git a/CorpusCallosum/data/generate_fsaverage_centroids.py b/CorpusCallosum/data/generate_fsaverage_centroids.py index c43072f7..9b7abf74 100644 --- a/CorpusCallosum/data/generate_fsaverage_centroids.py +++ b/CorpusCallosum/data/generate_fsaverage_centroids.py @@ -76,16 +76,16 @@ def main() -> None: except KeyError as err: raise OSError("FREESURFER_HOME environment variable is not set") from err - print(f"Loading fsaverage segmentation from: {fsaverage_aseg_path}") + logger.info(f"Loading fsaverage segmentation from: {fsaverage_aseg_path}") # Load fsaverage segmentation fsaverage_nib = nib.load(fsaverage_aseg_path) # Extract centroids - print("Extracting centroids from fsaverage...") + logger.info("Extracting centroids from fsaverage...") centroids_dst = get_centroids_from_nib(fsaverage_nib) - print(f"Found {len(centroids_dst)} anatomical structures with centroids") + logger.info(f"Found {len(centroids_dst)} anatomical structures with centroids") # Convert to JSON-serializable format centroids_serializable = convert_numpy_to_json_serializable(centroids_dst) @@ -96,11 +96,11 @@ def main() -> None: with open(centroids_output_path, 'w') as f: json.dump(centroids_serializable, f, indent=2) - print(f"Fsaverage centroids saved to: {centroids_output_path}") - print(f"Centroids file size: {centroids_output_path.stat().st_size} bytes") + logger.info(f"Fsaverage centroids saved to: {centroids_output_path}") + logger.info(f"Centroids file size: {centroids_output_path.stat().st_size} bytes") # Extract and save fsaverage affine matrix and header fields - print("Extracting fsaverage affine matrix and header fields...") + logger.info("Extracting fsaverage affine matrix and header fields...") fsaverage_affine = fsaverage_nib.affine.astype(float) # Convert to float for JSON serialization # Extract header fields needed for LTA @@ -138,28 +138,28 @@ def main() -> None: with open(combined_output_path, 'w') as f: json.dump(combined_data_serializable, f, indent=2) - print(f"Fsaverage affine and header data saved to: {combined_output_path}") - print(f"Combined file size: {combined_output_path.stat().st_size} bytes") - print(f"Affine matrix shape: {fsaverage_affine.shape}") - print(f"Header dims: {dims}, delta: {delta}") + logger.info(f"Fsaverage affine and header data saved to: {combined_output_path}") + logger.info(f"Combined file size: {combined_output_path.stat().st_size} bytes") + logger.info(f"Affine matrix shape: {fsaverage_affine.shape}") + logger.info(f"Header dims: {dims}, delta: {delta}") # Print some statistics label_ids = list(centroids_dst.keys()) - print(f"Label IDs range: {min(label_ids)} to {max(label_ids)}") - print("Sample centroids:") + logger.info(f"Label IDs range: {min(label_ids)} to {max(label_ids)}") + logger.info("Sample centroids:") for label_id in sorted(label_ids)[:5]: centroid = centroids_dst[label_id] - print(f" Label {label_id}: [{centroid[0]:.2f}, {centroid[1]:.2f}, {centroid[2]:.2f}]") - - print("Fsaverage affine matrix:") - print(fsaverage_affine) - - print("Fsaverage header fields:") - print(f" dims: {dims}") - print(f" delta: {delta}") - print(f" Mdc shape: {Mdc.shape}") - print(f" Pxyz_c: {Pxyz_c}") - print("Combined data structure created successfully") + logger.info(f" Label {label_id}: [{centroid[0]:.2f}, 
{centroid[1]:.2f}, {centroid[2]:.2f}]") + + logger.info("Fsaverage affine matrix:") + logger.info(fsaverage_affine) + + logger.info("Fsaverage header fields:") + logger.info(f" dims: {dims}") + logger.info(f" delta: {delta}") + logger.info(f" Mdc shape: {Mdc.shape}") + logger.info(f" Pxyz_c: {Pxyz_c}") + logger.info("Combined data structure created successfully") if __name__ == "__main__": diff --git a/CorpusCallosum/data/read_write.py b/CorpusCallosum/data/read_write.py index 503b5b12..abc5bf01 100644 --- a/CorpusCallosum/data/read_write.py +++ b/CorpusCallosum/data/read_write.py @@ -220,4 +220,4 @@ def load_fsaverage_data(data_path: str | Path) -> tuple[npt.NDArray[float], FSAv if affine_matrix.shape != (4, 4): raise ValueError(f"Expected 4x4 affine matrix, got shape {affine_matrix.shape}") - return affine_matrix, header_data, vox2ras_tkr \ No newline at end of file + return affine_matrix, header_data, vox2ras_tkr diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index eccdc1b6..82a81368 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -28,11 +28,12 @@ from CorpusCallosum.data.constants import ( CC_LABEL, + DEFAULT_INPUT_PATHS, + DEFAULT_OUTPUT_PATHS, FSAVERAGE_CENTROIDS_PATH, FSAVERAGE_DATA_PATH, FSAVERAGE_MIDDLE, - STANDARD_INPUT_PATHS, - STANDARD_OUTPUT_PATHS, + THIRD_VENTRICLE_LABEL, ) from CorpusCallosum.data.read_write import ( FSAverageHeader, @@ -231,45 +232,45 @@ def _slice_selection(a: str) -> SliceSelection: "--segmentation", "--seg", type=path_or_none, help="Path for corpus callosum and fornix segmentation 3D image.", - default=Path(STANDARD_OUTPUT_PATHS["segmentation"]), + default=Path(DEFAULT_OUTPUT_PATHS["segmentation"]), ) advanced.add_argument( "--cc_measures", type=path_or_none, help="Path for surface-based corpus callosum measures describing shape and volume for each image slice.", - default=Path(STANDARD_OUTPUT_PATHS["cc_measures"]), + default=Path(DEFAULT_OUTPUT_PATHS["cc_measures"]), ) advanced.add_argument( "--cc_mid_measures", type=path_or_none, help="Path for surface-based corpus callosum measures of the midslice describing CC shape and volume.", - default=STANDARD_OUTPUT_PATHS["cc_markers"], + default=DEFAULT_OUTPUT_PATHS["cc_markers"], ) advanced.add_argument( "--upright_lta", type=path_or_none, help="Path for upright LTA transform. This makes sure the midplane is at 128 in LR direction, but no nodding " "correction is applied.", - default=STANDARD_OUTPUT_PATHS["upright_lta"], + default=DEFAULT_OUTPUT_PATHS["upright_lta"], ) advanced.add_argument( "--orient_volume_lta", type=path_or_none, help="Path for orientation volume LTA transform. 
This makes sure the midplane is at 128 in LR direction, and " "the anterior and posterior commisures are on the coordinate line, standardizing the head orientation.", - default=STANDARD_OUTPUT_PATHS["orient_volume_lta"], + default=DEFAULT_OUTPUT_PATHS["orient_volume_lta"], ) advanced.add_argument( "--segmentation_in_orig", type=path_or_none, help="Path for corpus callosum and fornix segmentation in the input MRI space.", - default=STANDARD_OUTPUT_PATHS["segmentation_in_orig"], + default=DEFAULT_OUTPUT_PATHS["segmentation_in_orig"], ) advanced.add_argument( "--qc_image", type=ReplaceQCOutputDir, help="Path for QC visualization image (if it starts with {qc_output_dir}, that is replace by --qc_output_dir).", - default=STANDARD_OUTPUT_PATHS["qc_image"], + default=DEFAULT_OUTPUT_PATHS["qc_image"], ) advanced.add_argument( "--save_template_dir", @@ -282,20 +283,20 @@ def _slice_selection(a: str) -> SliceSelection: "--thickness_image", type=ReplaceQCOutputDir, help="Path for thickness image (if it starts with {qc_output_dir}, that is replace by --qc_output_dir).", - default=STANDARD_OUTPUT_PATHS["thickness_image"], + default=DEFAULT_OUTPUT_PATHS["thickness_image"], ) advanced.add_argument( "--surf", dest="cc_surf", type=path_or_none, help="Path for surf file.", - default=STANDARD_OUTPUT_PATHS["cc_surf"], + default=DEFAULT_OUTPUT_PATHS["cc_surf"], ) advanced.add_argument( "--thickness_overlay", type=path_or_none, help="Path for corpus callosum thickness overlay file.", - default=STANDARD_OUTPUT_PATHS["cc_thickness_overlay"], + default=DEFAULT_OUTPUT_PATHS["cc_thickness_overlay"], ) advanced.add_argument( "--cc_interactive_html", "--cc_html", @@ -303,33 +304,33 @@ def _slice_selection(a: str) -> SliceSelection: type=ReplaceQCOutputDir, help="Path to the corpus callosum interactive 3D visualization HTML file (if it starts with {qc_output_dir}, " "that is replace by --qc_output_dir).", - default=STANDARD_OUTPUT_PATHS["cc_html"], + default=DEFAULT_OUTPUT_PATHS["cc_html"], ) advanced.add_argument( "--cc_surf_vtk", type=path_or_none, - help=f"Path for vtk file, showing the CC 3D mesh. Example: {STANDARD_OUTPUT_PATHS['cc_surf_vtk']}.", + help=f"Path for vtk file, showing the CC 3D mesh. Example: {DEFAULT_OUTPUT_PATHS['cc_surf_vtk']}.", default=None, ) advanced.add_argument( "--softlabels_cc", type=path_or_none, help=f"Path for corpus callosum softlabels, which contains the soft labels of each voxel. " - f"Example: {STANDARD_OUTPUT_PATHS['softlabels_cc']}.", + f"Example: {DEFAULT_OUTPUT_PATHS['softlabels_cc']}.", default=None, ) advanced.add_argument( "--softlabels_fn", type=path_or_none, help=f"Path for fornix softlabels, which contains the soft labels of each voxel. " - f"Example: {STANDARD_OUTPUT_PATHS['softlabels_fn']}.", + f"Example: {DEFAULT_OUTPUT_PATHS['softlabels_fn']}.", default=None, ) advanced.add_argument( "--softlabels_background", type=path_or_none, help=f"Path for background softlabels, which contains the probability of each voxel. 
" - f"Example: {STANDARD_OUTPUT_PATHS['softlabels_background']}.", + f"Example: {DEFAULT_OUTPUT_PATHS['softlabels_background']}.", default=None, ) ############ END OF OUTPUT PATHS ############ @@ -372,10 +373,10 @@ def options_parse() -> argparse.Namespace: if args.subject_dir: # Create standard FreeSurfer subdirectories if not args.conf_name: - args.conf_name = args.subject_dir / STANDARD_INPUT_PATHS["conf_name"] + args.conf_name = args.subject_dir / DEFAULT_INPUT_PATHS["conf_name"] if not args.aseg_name: - args.aseg_name = args.subject_dir / STANDARD_INPUT_PATHS["aseg_name"] + args.aseg_name = args.subject_dir / DEFAULT_INPUT_PATHS["aseg_name"] all_paths = ("segmentation", "segmentation_in_orig", "cc_measures", "upright_lta", "orient_volume_lta", "cc_surf", "softlabels_cc", "softlabels_fn", "softlabels_background", "cc_mid_measures", "cc_thickness_overlay", @@ -485,7 +486,7 @@ def localize_ac_pc( """ # get center of third ventricle from aseg and map to fsaverage space - third_ventricle_mask = np.asarray(aseg_nib.dataobj) == 4 + third_ventricle_mask = np.asarray(aseg_nib.dataobj) == THIRD_VENTRICLE_LABEL third_ventricle_center = np.argwhere(third_ventricle_mask).mean(axis=0) third_ventricle_center_vox = apply_transform_to_pt(third_ventricle_center, orig_fsaverage_vox2vox, inv=False) @@ -795,7 +796,7 @@ def main( slice_results, slice_io_futures = recon_cc_surf_measures_multi( segmentation=cc_fn_seg_labels, slice_selection=slice_selection, - temp_seg_affine=fsavg_affine, + upright_affine=fsavg_affine, midslices=midslices, ac_coords=ac_coords, pc_coords=pc_coords, @@ -816,7 +817,7 @@ def main( "Large area changes detected between consecutive slices, this is likely due to a segmentation error." ) - # Get middle slice result for backward compatibility + # Get middle slice result middle_slice_result = slice_results[len(slice_results) // 2] if len(middle_slice_result['split_contours']) <= 5: diff --git a/CorpusCallosum/localization/localization_inference.py b/CorpusCallosum/localization/localization_inference.py index d4c0ae29..20c6e0f1 100644 --- a/CorpusCallosum/localization/localization_inference.py +++ b/CorpusCallosum/localization/localization_inference.py @@ -40,7 +40,7 @@ def load_model(device: torch.device) -> DenseNet: Returns ------- DenseNet - Loaded and initialized model in evaluation mode + Loaded and initialized model in evaluation mode. """ # Initialize model architecture (must match training) @@ -115,11 +115,15 @@ def preprocess_volume( Returns ------- dict[str, torch.Tensor] - Dictionary containing preprocessed image tensor + Dictionary containing preprocessed image tensor. """ if transform is None: transform = get_transforms() + # During training we used AC/PC coordinates, but during inference + # we approximate this by the center of the third ventricle. + # Therefore we put in the third ventricle center as dummy AC/PC coordinates + # for cropping the image. sample = {"image": image_volume[None], "AC_center": center_pt[1:][None], "PC_center": center_pt[1:][None]} # Apply transforms @@ -145,15 +149,15 @@ def run_inference( Parameters ---------- model : DenseNet - Trained model for inference + Trained model for inference. image_volume : np.ndarray - Input volume as numpy array + Input volume as numpy array. third_ventricle_center : np.ndarray - Initial center point estimate for cropping + Initial center point estimate for cropping. device : torch.device, optional - Device to run inference on, by default None + Device to run inference on, by default None. 
transform : transforms.Transform, optional - Custom transform pipeline, by default None + Custom transform pipeline, defaults to preconfigured transforms of `get_transforms`. Returns ------- @@ -168,7 +172,6 @@ def run_inference( """ if device is None: device = next(model.parameters()).device - #device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # prepend zero to third_ventricle_center third_ventricle_center = np.concatenate([np.zeros(1), third_ventricle_center]) @@ -203,13 +206,13 @@ def run_inference_on_slice( Parameters ---------- model : torch.nn.Module - Trained model for AC-PC detection + Trained model for AC-PC detection. image_slice : np.ndarray 3D image mid-slices to run inference on in RAS. center_pt : np.ndarray - Initial center point estimate for cropping + Initial center point estimate for cropping. debug_output : str, optional - Path to save debug visualization, by default None + Path to save debug visualization, by default None. Returns ------- diff --git a/CorpusCallosum/paint_cc_into_pred.py b/CorpusCallosum/paint_cc_into_pred.py index 2a577d1b..f344553b 100644 --- a/CorpusCallosum/paint_cc_into_pred.py +++ b/CorpusCallosum/paint_cc_into_pred.py @@ -25,6 +25,7 @@ from numpy import typing as npt from scipy import ndimage +import FastSurferCNN.utils.logging as logging from CorpusCallosum.data.constants import FORNIX_LABEL, SUBSEGMENT_LABELS from FastSurferCNN.data_loader.conform import is_conform from FastSurferCNN.reduce_to_aseg import reduce_to_aseg_and_save @@ -34,6 +35,8 @@ _T = TypeVar("_T", bound=np.number) +logger = logging.get_logger(__name__) + HELPTEXT = """ Script to add corpus callosum segmentation (CC, FreeSurfer IDs 251-255) to deep-learning prediction (e.g. aparc.DKTatlas+aseg.deep.mgz). @@ -266,7 +269,7 @@ def correct_wm_ventricles( # Command Line options are error checking done here options = argument_parse() - print(f"Reading inputs: {options.input_cc} {options.input_pred}...") + logger.info(f"Reading inputs: {options.input_cc} {options.input_pred}...") cc_seg_image = cast(nib.analyze.SpatialImage, nib.load(options.input_cc)) cc_seg_data = np.asanyarray(cc_seg_image.dataobj) aseg_image = cast(nib.analyze.SpatialImage, nib.load(options.input_pred)) @@ -287,7 +290,7 @@ def correct_wm_ventricles( pred_with_cc = paint_in_cc(aseg_data, cc_seg_data) # Apply WM and ventricle corrections - print("Applying white matter and ventricle corrections...") + logger.info("Applying white matter and ventricle corrections...") fornix_mask = cc_seg_data == FORNIX_LABEL voxel_size = tuple(aseg_image.header.get_zooms()) pred_corrected = correct_wm_ventricles(aseg_data, fornix_mask, voxel_size) @@ -322,8 +325,8 @@ def correct_wm_ventricles( final_wm = np.sum((pred_corrected == 2) | (pred_corrected == 41)) final_ventricles = np.sum((pred_corrected == 4) | (pred_corrected == 43)) - print(f"Final segmentation: CC={final_cc}, Fornix={final_fornix}, WM={final_wm}, Ventricles={final_ventricles}") - print(f"Changes: CC +{final_cc-initial_cc}, Fornix {final_fornix-initial_fornix}, WM {final_wm-initial_wm}") + logger.info(f"Final segmentation: CC={final_cc}, Fornix={final_fornix}, WM={final_wm}, Ventricles={final_ventricles}") + logger.info(f"Changes: CC +{final_cc-initial_cc}, Fornix {final_fornix-initial_fornix}, WM {final_wm-initial_wm}") if rta_fut is not None: _ = rta_fut.result() diff --git a/CorpusCallosum/registration/mapping_helpers.py b/CorpusCallosum/registration/mapping_helpers.py index dcecd9fd..821d4432 100644 --- 
a/CorpusCallosum/registration/mapping_helpers.py +++ b/CorpusCallosum/registration/mapping_helpers.py @@ -86,13 +86,6 @@ def correct_nodding(ac_pt: npt.NDArray[float], pc_pt: npt.NDArray[float]) -> npt ] ) - # plot vector ac_pc_vec and posterior_vector - # fig, ax = plt.subplots() - # ax.quiver(0, 0, ac_pc_vec[0], ac_pc_vec[1], color='red', label='ac_pc_vec') - # ax.quiver(0, 0, posterior_vector[0], posterior_vector[1], color='blue', label='posterior_vector') - # ax.legend() - # plt.show() - return rotation_matrix @@ -260,7 +253,7 @@ def apply_transform_to_volume( return transformed -def make_affine(simpleITKImage: 'sitk.Image') -> npt.NDArray[float]: +def make_affine(simpleITKImage: sitk.Image) -> npt.NDArray[float]: """Create an affine transformation matrix from a SimpleITK image. Parameters @@ -381,16 +374,16 @@ def interpolate_midplane( Parameters ---------- orig : nib.Nifti1Image - Original image + Original image. orig_fsaverage_vox2vox : np.ndarray - Original to fsaverage space transformation matrix + Original to fsaverage space transformation matrix. slices_to_analyze : int - Number of slices to analyze around midplane + Number of slices to analyze around midplane. Returns ------- np.ndarray - Interpolated image data at midplane + Interpolated image data at midplane. """ # slice_thickness = 9+slices_to_analyze-1 @@ -420,7 +413,7 @@ def interpolate_midplane( from scipy.ndimage import map_coordinates transformed = map_coordinates( - orig.get_fdata(), + np.asarray(orig.dataobj), grid_orig[0:3, :], # use only x,y,z coordinates (drop homogeneous coordinate) order=2, mode="constant", diff --git a/CorpusCallosum/segmentation/segmentation_inference.py b/CorpusCallosum/segmentation/segmentation_inference.py index cb9a5a85..692f0039 100644 --- a/CorpusCallosum/segmentation/segmentation_inference.py +++ b/CorpusCallosum/segmentation/segmentation_inference.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Iterator from pathlib import Path import nibabel as nib @@ -40,7 +41,7 @@ def load_model(device: torch.device | None = None) -> FastSurferVINN: Returns ------- FastSurferVINN - Loaded and initialized model in evaluation mode + Loaded and initialized model in evaluation mode. """ if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -94,20 +95,20 @@ def run_inference( Parameters ---------- model : FastSurferVINN - Trained model + Trained model. image_slice : np.ndarray LIA-oriented input image as numpy array of shape (L, I, A). ac_center : np.ndarray - Anterior commissure coordinates + Anterior commissure coordinates. pc_center : np.ndarray - Posterior commissure coordinates + Posterior commissure coordinates. voxel_size : float - Voxel size in mm + Voxel size in mm. device : torch.device or None, optional Device to run inference on, by default None. If None, uses the device of the model. transform : transforms.Transform or None, optional - Custom transform pipeline, by default None + Custom transform pipeline, by default None. 
Returns ------- @@ -154,13 +155,42 @@ def run_inference( return [x.transpose(0, 2, 3, 1) for x in (labels, _inputs.cpu().numpy(), softlabels)] -def load_validation_data(path): +def load_validation_data( + path: str | Path +) -> tuple[npt.NDArray[str], npt.NDArray[float], npt.NDArray[float], Iterator[int], npt.NDArray[str], list[str]]: + """Load validation data from CSV file and compute label widths. + + Reads a CSV file containing image paths, label paths, and AC/PC coordinates, + then computes the width (number of slices with non-zero labels) for each label file. + Parameters + ---------- + path : str or Path + Path to the CSV file containing validation data. The CSV should have columns: + image, label, AC_center_x, AC_center_y, AC_center_z, + PC_center_x, PC_center_y, PC_center_z. + + Returns + ------- + images : npt.NDArray[str] + Array of image file paths. + ac_centers : npt.NDArray[float] + Array of anterior commissure coordinates (x, y, z). + pc_centers : npt.NDArray[float] + Array of posterior commissure coordinates (x, y, z). + label_widths : Iterator[int] + Iterator yielding the number of slices with non-zero labels for each label file. + labels : npt.NDArray[str] + Array of label file paths. + subj_ids : list[str] + List of subject IDs (from CSV index). + """ import pandas as pd + data = pd.read_csv(path, index_col=0, header=None) - data.columns = ["image", "label", "AC_center_x", "AC_center_y", "AC_center_z", + data.columns = ["image", "label", "AC_center_x", "AC_center_y", "AC_center_z", "PC_center_x", "PC_center_y", "PC_center_z"] - + ac_centers = data[["AC_center_x", "AC_center_y", "AC_center_z"]].values pc_centers = data[["PC_center_x", "PC_center_y", "PC_center_z"]].values images = data["image"].values @@ -168,8 +198,20 @@ def load_validation_data(path): subj_ids = data.index.values.tolist() def _load(label_path: str | Path) -> int: + """Compute the width of non-zero slices in a label image. + + Parameters + ---------- + label_path : str or Path + Path to the label image file + + Returns + ------- + int + Number of slices containing non-zero labels, or total slices if <= 100 + """ label_img = nib.load(label_path) - + if label_img.shape[0] > 100: # check which slices have non-zero values label_data = np.asarray(label_img.dataobj) @@ -179,17 +221,41 @@ def _load(label_path: str | Path) -> int: return last_nonzero - first_nonzero else: return label_img.shape[0] + label_widths = thread_executor().map(_load, data["label"]) - + return images, ac_centers, pc_centers, label_widths, labels, subj_ids -def one_hot_to_label(one_hot, label_ids=None): +def one_hot_to_label( + one_hot: npt.NDArray[float], + label_ids: list[int] | None = None +) -> npt.NDArray[int]: + """Convert one-hot encoded segmentation to label map. + + Converts a one-hot encoded segmentation array to discrete labels by taking + the argmax along the last axis and optionally mapping to specific label values. + + Parameters + ---------- + one_hot : npt.NDArray[float] + One-hot encoded segmentation array of shape (..., num_classes). + label_ids : list[int] or None, optional + List of label IDs to map classes to. If None, defaults to [0, 192, 250]. + The index in this list corresponds to the class index from argmax. + + Returns + ------- + npt.NDArray[int] + Label map with discrete integer labels. 
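+
+    Examples
+    --------
+    A minimal, hypothetical call (array shapes here are illustrative only):
+
+    >>> probs = np.zeros((4, 4, 1, 3))
+    >>> probs[..., 1] = 1.0
+    >>> int(one_hot_to_label(probs).max())  # class index 1 maps to label 192
+    192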
+ """ if label_ids is None: label_ids = [0, 192, 250] + label = np.argmax(one_hot, axis=3) if label_ids is not None: label = np.asarray(label_ids)[label] + return label @@ -219,11 +285,11 @@ def run_inference_on_slice( Returns ------- results: np.ndarray - Label map after one-hot conversion + Label map after one-hot conversion. inputs: np.ndarray - Preprocessed input image + Preprocessed input image. outputs_soft: npt.NDArray[float] - Softlabel outputs (non-discrete) + Softlabel outputs (non-discrete). """ # add zero in front of AC_center and PC_center diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py index 90139496..51cbffe3 100644 --- a/CorpusCallosum/segmentation/segmentation_postprocessing.py +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -30,14 +30,14 @@ def find_component_boundaries(labels_arr: npt.NDArray[int], component_id: int) - Parameters ---------- labels_arr : np.ndarray - Labeled array from connected components analysis + Labeled array from connected components analysis. component_id : int - ID of the component to find boundaries for + ID of the component to find boundaries for. Returns ------- np.ndarray - Array of shape (N, 3) containing boundary coordinates + Array of shape (N, 3) containing boundary coordinates. Notes ----- @@ -69,19 +69,21 @@ def find_minimal_connection_path( Parameters ---------- boundary_coords1 : np.ndarray - Boundary coordinates of first component, shape (N1, 3) + Boundary coordinates of first component, shape (N1, 3). boundary_coords2 : np.ndarray - Boundary coordinates of second component, shape (N2, 3) + Boundary coordinates of second component, shape (N2, 3). max_distance : float, default=3.0 - Maximum distance to consider for connection, by default 3.0 + Maximum distance to consider for connection, by default 3.0. Returns ------- tuple[np.ndarray, np.ndarray] or None If a valid connection is found: - - point1 : Coordinates on first boundary - - point2 : Coordinates on second boundary - None if no connection within max_distance is found + + - point1 : Coordinates on first boundary + - point2 : Coordinates on second boundary + + None if no connection within max_distance is found. Notes ----- @@ -113,14 +115,14 @@ def create_connection_line(point1: np.ndarray, point2: np.ndarray) -> list[tuple Parameters ---------- point1 : np.ndarray - Starting point coordinates, shape (3,) + Starting point coordinates, shape (3,). point2 : np.ndarray - Ending point coordinates, shape (3,) + Ending point coordinates, shape (3,). Returns ------- list[tuple[int, int, int]] - List of (x, y, z) coordinates forming the connection line + List of (x, y, z) coordinates forming the connection line. Notes ----- @@ -166,14 +168,14 @@ def connect_nearby_components(seg_arr: np.ndarray, max_connection_distance: floa Parameters ---------- seg_arr : np.ndarray - Input binary segmentation array + Input binary segmentation array. max_connection_distance : float, optional - Maximum distance to connect components, by default 3.0 + Maximum distance to connect components, by default 3.0. Returns ------- np.ndarray - Segmentation array with minimal connections added between nearby components + Segmentation array with minimal connections added between nearby components. Notes ----- @@ -289,16 +291,16 @@ def get_cc_volume_voxel( Parameters ---------- desired_width_mm : int - Desired width of the CC in millimeters + Desired width of the CC in millimeters. 
cc_mask : np.ndarray
-        Binary mask of the corpus callosum
+        Binary mask of the corpus callosum.
     voxel_size : tuple[float, float, float]
-        Voxel size in millimeters (x, y, z)
+        Voxel size in millimeters (x, y, z).
 
     Returns
     -------
     float
-        Volume of the CC in cubic millimeters
+        Volume of the CC in cubic millimeters.
 
     Raises
     ------
@@ -353,17 +355,15 @@ def get_cc_volume_contour(cc_contours: list[np.ndarray],
 
     Parameters
     ----------
-    desired_width_mm : int
-        Desired width of the CC in millimeters
     cc_contours : list[np.ndarray]
-        List of CC contours for each slice in the left-right direction
+        List of CC contours for each slice in the left-right direction.
     voxel_size : tuple[float, float, float]
-        Voxel size in millimeters (x, y, z)
+        Voxel size in millimeters (x, y, z).
 
     Returns
     -------
     float
-        Volume of the CC in cubic millimeters
+        Volume of the CC in cubic millimeters.
 
     Raises
     ------
@@ -432,14 +432,14 @@ def extract_largest_connected_component(
     Parameters
     ----------
     seg_arr : np.ndarray
-        Input binary segmentation array
+        Input binary segmentation array.
     max_connection_distance : float, optional
-        Maximum distance to connect components, by default 3.0
+        Maximum distance to connect components, by default 3.0.
 
     Returns
     -------
     np.ndarray
-        Binary mask of the largest connected component
+        Binary mask of the largest connected component.
 
     Notes
     -----
@@ -492,7 +492,7 @@ def clean_cc_segmentation(
     Parameters
     ----------
     seg_arr : npt.NDArray[int]
-        Input segmentation array with CC (192) and fornix (250) labels
+        Input segmentation array with CC (192) and fornix (250) labels.
     max_connection_distance : float, default=3.0
         Maximum distance to connect components.
 
diff --git a/CorpusCallosum/shape/cc_mesh.py b/CorpusCallosum/shape/cc_mesh.py
index ef372073..3be0f82a 100644
--- a/CorpusCallosum/shape/cc_mesh.py
+++ b/CorpusCallosum/shape/cc_mesh.py
@@ -197,11 +197,11 @@ def plot_mesh(
         assert self.v is not None and self.t is not None, "Mesh has not been created yet"
 
         if len(self.v) == 0:
-            print("Warning: No vertices in mesh to plot")
+            logger.warning("No vertices in mesh to plot")
             return
 
         if len(self.t) == 0:
-            print("Warning: No faces in mesh to plot")
+            logger.warning("No faces in mesh to plot")
             return
 
         # Define available colormaps
@@ -236,7 +236,7 @@
         # Select the colormap
         if colormap not in colormaps:
-            print(f"Warning: Unknown colormap '{colormap}'. Using 'red_to_blue' instead.")
+            logger.warning(f"Unknown colormap '{colormap}'. Using 'red_to_blue' instead.")
             colormap = "red_to_blue"
 
         selected_colormap = colormaps[colormap]
@@ -741,7 +741,7 @@ def create_mesh(self, lr_center: float = 0, closed: bool = False, smooth: int =
         # Filter out None contours and get their indices
         valid_contours = [(i, c) for i, c in enumerate(self.contours) if c is not None]
         if not valid_contours:
-            print("Warning: No valid contours found")
+            logger.warning("No valid contours found")
             self.v = np.array([])
             self.t = np.array([])
             return
diff --git a/CorpusCallosum/shape/cc_metrics.py b/CorpusCallosum/shape/cc_metrics.py
index 491191d7..7ae55f52 100644
--- a/CorpusCallosum/shape/cc_metrics.py
+++ b/CorpusCallosum/shape/cc_metrics.py
@@ -14,6 +14,9 @@
 
 import numpy as np
 
+import FastSurferCNN.utils.logging as logging
+
+logger = logging.get_logger(__name__)
 
 def calculate_cc_index(cc_contour: np.ndarray) -> float:
     """Calculate CC index based on three perpendicular measurements.
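+
+    The index is the sum of the three perpendicular measurements
+    (``anterior_distance``, ``posterior_distance`` and ``top_distance``)
+    divided by the anterior-posterior extent ``ap_distance`` of the contour,
+    as computed at the end of this function.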
@@ -80,8 +83,8 @@ def get_intersections(start_point: np.ndarray, direction: np.ndarray) -> np.ndar middle_ints = get_intersections(midpoint, perpendicular_unit) if len(middle_ints) != 2: - print( - f"WARNING: The perpendicular line should intersect the contour twice, " + logger.warning( + f"The perpendicular line should intersect the contour twice, " f"but it intersects {len(middle_ints)} times" ) @@ -95,54 +98,5 @@ def get_intersections(start_point: np.ndarray, direction: np.ndarray) -> np.ndar cc_index = (anterior_distance + posterior_distance + top_distance) / ap_distance - # fig, ax = plt.subplots(figsize=(8, 6)) - - # # Plot the CC contour - # ax.plot(cc_contour[0], cc_contour[1], 'k-', linewidth=1) - # # add line from last to first - # ax.plot([cc_contour[0,-1], cc_contour[0,0]], [cc_contour[1,-1], cc_contour[1,0]], - # 'k-', linewidth=1) - - # # Plot AP line - # ax.plot([cc_contour[0,anterior_idx], cc_contour[0,posterior_idx]], - # [cc_contour[1,anterior_idx], cc_contour[1,posterior_idx]], - # 'r--', linewidth=1)#, label='Anterior-posterior line') - - # # Plot the three measurement lines - # for i, ints in enumerate(zip(anterior_intersections[:-1], anterior_intersections[1:])): - - # if i != 1: - # ax.plot([ints[0][0], ints[1][0]], [ints[0][1], ints[1][1]], - # 'b-', linewidth=1, label='Measurement line horizontal' if i==0 else None) - - # ax.plot([middle_ints[0,0], middle_ints[1,0]], [middle_ints[0,1], middle_ints[1,1]], - # 'g-', linewidth=1, label='Measurement lines vertical') - - # print(middle_ints[0,], middle_ints[1,1]) - # print(midpoint[1], midpoint[0]) - # ax.plot([middle_ints[0,0], midpoint[0]], [middle_ints[0,1], midpoint[1]], - # 'r--', linewidth=1)#, label='Superior-inferior line') - - # #plt.scatter(midpoint[0], midpoint[1], color='green', s=20) - - # ax.set_aspect('equal') - # ax.legend() - # # add gray background to CC contour - # # Fill the inside of the contour with a gray shade - # from matplotlib.path import Path - # from matplotlib.patches import PathPatch - - # # Create a path from the contour points - # contour_path = Path(np.array([cc_contour[0], cc_contour[1]]).T) - - # # Create a patch from the path and add it to the axes - # patch = PathPatch(contour_path, facecolor='gray', alpha=0.2, edgecolor=None) - # ax.add_patch(patch) - - # # invert x - # ax.invert_xaxis() - # #ax.set_title('CC Index Measurement Lines') - # plt.axis('off') - # plt.show() return cc_index diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/cc_postprocessing.py index a4ed63cf..20bc4988 100644 --- a/CorpusCallosum/shape/cc_postprocessing.py +++ b/CorpusCallosum/shape/cc_postprocessing.py @@ -47,12 +47,12 @@ LIA_ORIENTATION[2,1] = -1 -def create_slice_affine(temp_seg_affine: np.ndarray, slice_idx: int, fsaverage_middle: int) -> np.ndarray: +def create_slice_affine(upright_affine: np.ndarray, slice_idx: int, fsaverage_middle: int) -> np.ndarray: """Create slice-specific affine transformation matrix. Parameters ---------- - temp_seg_affine : np.ndarray + upright_affine : np.ndarray Base 4x4 affine transformation matrix. slice_idx : int Index of the slice to transform. @@ -64,7 +64,7 @@ def create_slice_affine(temp_seg_affine: np.ndarray, slice_idx: int, fsaverage_m np.ndarray Modified 4x4 affine transformation matrix for the specific slice. 
""" - slice_affine = temp_seg_affine.copy() + slice_affine = upright_affine.copy() slice_affine[0, 3] = -fsaverage_middle + slice_idx return slice_affine @@ -73,7 +73,7 @@ def create_slice_affine(temp_seg_affine: np.ndarray, slice_idx: int, fsaverage_m def recon_cc_surf_measures_multi( segmentation: np.ndarray, slice_selection: SliceSelection, - temp_seg_affine: np.ndarray, + upright_affine: np.ndarray, midslices: np.ndarray, ac_coords: np.ndarray, pc_coords: np.ndarray, @@ -93,8 +93,8 @@ def recon_cc_surf_measures_multi( 3D segmentation array. slice_selection : str Which slices to process ('middle', 'all', or slice number). - temp_seg_affine : np.ndarray - Base affine transformation matrix. + upright_affine : np.ndarray + Base affine transformation matrix (fsaverage, upright space). midslices : np.ndarray Array of mid-sagittal slices. ac_coords : np.ndarray @@ -157,7 +157,7 @@ def recon_cc_surf_measures_multi( num_slices = 1 slice_iterator = [int(slice_selection)] - it_affine = map(partial(create_slice_affine, temp_seg_affine, fsaverage_middle=FSAVERAGE_MIDDLE), slice_iterator) + it_affine = map(partial(create_slice_affine, upright_affine, fsaverage_middle=FSAVERAGE_MIDDLE), slice_iterator) iterator = process_executor().map(_each_slice, iter(slice_iterator), it_affine, chunksize=1) cc_mesh = CCMesh(num_slices=num_slices) @@ -309,8 +309,9 @@ def recon_cc_surf_measure( Returns ------- - dict of measures + measures : dict Dictionary containing measurements if successful, including: + - cc_index : float - Corpus callosum shape index. - circularity : float - Shape circularity measure. - areas : np.ndarray - Areas of subdivided regions. @@ -326,8 +327,11 @@ def recon_cc_surf_measure( - thickness_measurement_points : np.ndarray - Points where thickness was measured. - slice_index : int - Index of the processed slice. contour_with_thickness : np.ndarray + Contour points with thickness information. anterior_endpoint_index : int + Index of the anterior endpoint on the contour. posterior_endpoint_index : int + Index of the posterior endpoint on the contour. Raises ------ diff --git a/CorpusCallosum/shape/cc_subsegment_contour.py b/CorpusCallosum/shape/cc_subsegment_contour.py index 4055a026..c1e37359 100644 --- a/CorpusCallosum/shape/cc_subsegment_contour.py +++ b/CorpusCallosum/shape/cc_subsegment_contour.py @@ -84,7 +84,7 @@ def minimum_bounding_rectangle(points): return rval -def calc_subsegment_area(split_contours: list[npt.NDArray[_TS]]) -> npt.NDArray[_TS]: +def calc_subsegment_areas(split_contours: list[npt.NDArray[_TS]]) -> npt.NDArray[_TS]: """Calculate area of each subsegment using the shoelace formula. Parameters @@ -124,8 +124,9 @@ def subsegment_midline_orthogonal(midline, area_weights, contour, plot=True, ax= Returns ------- - subsegment_area : list[np.ndarray] - split_contours : list[np.ndarray] + subsegment_areas : list of float + List of subsegment areas. + split_contours : list of np.ndarray List of contour arrays for each subsegment. 
""" @@ -314,7 +315,7 @@ def subsegment_midline_orthogonal(midline, area_weights, contour, plot=True, ax= ax.axis("equal") plt.show() - return calc_subsegment_area(split_contours), split_contours + return calc_subsegment_areas(split_contours), split_contours def hampel_subdivide_contour(contour, num_rays, plot=False, ax=None): @@ -355,11 +356,8 @@ def hampel_subdivide_contour(contour, num_rays, plot=False, ax=None): # get angle of lower edge of rectangle to x-axis angle_lower_edge = np.arctan2( lowest_points[1, 1] - lowest_points[0, 1], lowest_points[1, 0] - lowest_points[0, 0] - ) # % (np.pi) - - # steps = np.pi / num_rays + ) - # print(np.degrees(angle_lower_edge)) # get angles for equally spaced rays angles = np.linspace(-angle_lower_edge, -angle_lower_edge + np.pi, num_rays + 2, endpoint=True) # + np.pi *3 angles = angles[1:-1] @@ -453,7 +451,7 @@ def hampel_subdivide_contour(contour, num_rays, plot=False, ax=None): ax.axis("equal") plt.show() - return calc_subsegment_area(split_contours), split_contours + return calc_subsegment_areas(split_contours), split_contours def subdivide_contour( @@ -519,17 +517,6 @@ def subdivide_contour( if hline_anchor is not None: extremes = (np.array([max_x, hline_anchor[1]]), np.array([min_x, hline_anchor[1]])) - - # only keep x values of extremes and set y 5 mm below most inferior point of contour - # if hline_anchor is None: - # most_inferior_point = np.min(contour[1]) - # extremes = (np.array([extremes[0][0], most_inferior_point - 5]), - # np.array([extremes[1][0], most_inferior_point - 5])) - # else: - # # get y difference between extremes and hline_anchor - # y_diff = extremes[1][1] - hline_anchor[1] - # extremes = (np.array([extremes[0][0], extremes[0][1] - y_diff]), - # np.array([extremes[1][0], extremes[1][1] - y_diff])) else: extremes = (contour[:, min_x_index].copy(), contour[:, max_x_index].copy()) # Calculate the line between the extreme points @@ -557,18 +544,6 @@ def subdivide_contour( intersection = start_point + line_unit_vector * np.dot(hline_anchor - start_point, line_unit_vector) # get distance closest point on line to hline_anchor distance = np.linalg.norm(intersection - hline_anchor) - - # import matplotlib.pyplot as plt - # plt.close() - # fig, ax = plt.subplots(1,1,figsize=(8, 6)) - # ax.plot(contour[0], contour[1], 'b-', label='Original Contour') - # ax.plot(hline_anchor[0], hline_anchor[1], 'ro', label='Hline Anchor') - # ax.plot(intersection[0], intersection[1], 'go', label='Intersection') - # ax.plot(start_point[0], start_point[1], 'bo', label='Start Point') - # ax.plot(end_point[0], end_point[1], 'go', label='End Point') - # ax.legend() - # plt.show() - # move start and end point the same distance start_point = extremes[0] + distance * perp_vector end_point = extremes[1] + distance * perp_vector @@ -639,54 +614,12 @@ def subdivide_contour( (first_intersection[:, None], contour[:, first_index:second_index], second_intersection[:, None]) ) - # import matplotlib.pyplot as plt - # plt.close() - # fig, ax = plt.subplots(1,1,figsize=(8, 6)) - # ax.plot(contour[0], contour[1], 'b-', label='Original Contour') - # ax.plot(contour[0][0], contour[1][0], 'bo', label='First contour point') - # ax.plot(first_intersection[0], first_intersection[1], 'ro', label='First Intersection') - # ax.plot(second_intersection[0], second_intersection[1], 'go', label='Second Intersection') - # # ax.plot(contour[:, :first_index][0], contour[:, :first_index][1]+0.5, 'r-', label='First Segment') - # # ax.plot(contour[:, second_index+1:][0], contour[:, 
second_index+1:][1]+1, 'g-', label='Second Segment') - # ax.plot(contour[:, first_index:second_index][0], - # contour[:, first_index:second_index][1]+0.5, 'r-', label='Segment') - # ax.plot(start_to_cutoff[0], start_to_cutoff[1], 'g-', label='Start to Cutoff') - # ax.legend() - # plt.show() # connect first and second half split_contours.append(start_to_cutoff) else: raise ValueError("No intersections found, this should not happen") - # if plot: - # import matplotlib.pyplot as plt - # plt.figure(figsize=(8, 6)) - # plt.plot(contour[0], contour[1], 'b-', label='Original Contour') - # plt.plot(extremes[0][0], extremes[0][1], 'rx', markersize=8, label='Start Point') - # plt.plot(extremes[1][0], extremes[1][1], 'gx', markersize=8, label='End Point') - # for i, split_contour in enumerate(split_contours): - # plt.plot(split_contour[0], split_contour[1], label=f'Split {i+1}') - # plt.scatter(split_contour[0], split_contour[1], s=10) # Plot vertices - # plt.title('Split Contours') - # plt.xlabel('X') - # plt.ylabel('Y') - # plt.legend() - # plt.axis('equal') - # plt.show() - - # # same plot but segment are moved apart by 5 mm - # plt.figure(figsize=(8, 6)) - # for i, split_contour in enumerate(split_contours): - # plt.plot(split_contour[0], split_contour[1]+i*5, label=f'Split {i+1}') - # plt.scatter(split_contour[0], split_contour[1]+i*5, s=10) # Plot vertices - # plt.title('Split Contours') - # plt.xlabel('X') - # plt.ylabel('Y') - # plt.legend() - # plt.axis('equal') - # plt.show() - if plot: # make vline at every split point split_points_vlines_start = (np.array(split_points) - perp_vector * 1).T @@ -795,7 +728,7 @@ def subdivide_contour( ax.axis("equal") plt.show() - return calc_subsegment_area(split_contours), split_contours + return calc_subsegment_areas(split_contours), split_contours def transform_to_acpc_standard(contour_ras, ac_pt_ras, pc_pt_ras): @@ -885,7 +818,7 @@ def preprocess_cc(cc_label_nib, paths_csv, subj_id): 2D coordinates of posterior commissure. """ - cc_mask = cc_label_nib.get_fdata() == 192 + cc_mask = np.asarray(cc_label_nib.dataobj) == 192 cc_mask = cc_mask[cc_mask.shape[0] // 2] posterior_commisure_center = paths_csv.loc[subj_id, "PC_center_r":"PC_center_s"].to_numpy().astype(float) diff --git a/CorpusCallosum/shape/cc_thickness.py b/CorpusCallosum/shape/cc_thickness.py index 69bfff04..a7f10d1a 100644 --- a/CorpusCallosum/shape/cc_thickness.py +++ b/CorpusCallosum/shape/cc_thickness.py @@ -297,7 +297,7 @@ def cc_thickness( anterior_endpoint_idx: int, posterior_endpoint_idx: int, n_points: int = 100 -) -> tuple[np.ndarray, np.ndarray]: +) -> tuple[float, float, float, np.ndarray, list[np.ndarray], list[np.ndarray], int, int]: """Calculate corpus callosum thickness using Laplace equation. Parameters @@ -313,10 +313,22 @@ def cc_thickness( Returns ------- - thickness_values : np.ndarray - Array of thickness measurements. - measurement_points : np.ndarray - Array of points where thickness was measured. + midline_length : float + Total length of the midline. + thickness : float + Mean thickness across all level paths. + curvature : float + Mean absolute curvature in degrees. + midline_equidistant : np.ndarray + Equidistant points along the midline. + levelpaths : list[np.ndarray] + Level paths for thickness measurement. + contour_with_thickness : list[np.ndarray] + Contour coordinates with thickness information. + anterior_endpoint_idx : int + Updated index of anterior endpoint. + posterior_endpoint_idx : int + Updated index of posterior endpoint. 
Notes ----- diff --git a/CorpusCallosum/transforms/localization_transforms.py b/CorpusCallosum/transforms/localization_transforms.py index 12e1e5ac..e129fc82 100644 --- a/CorpusCallosum/transforms/localization_transforms.py +++ b/CorpusCallosum/transforms/localization_transforms.py @@ -28,17 +28,23 @@ class CropAroundACPCFixedSize(RandomizableTransform, MapTransform): Parameters ---------- keys : list[str] - Keys of the data dictionary to apply the transform to + Keys of the data dictionary to apply the transform to. fixed_size : tuple[int, int] - Fixed size of the crop window (width, height) + Fixed size of the crop window (width, height). allow_missing_keys : bool, optional - Whether to allow missing keys in the data dictionary, by default False + Whether to allow missing keys in the data dictionary, by default False. random_translate : int, default=0 Maximum random translation in voxels. + Raises + ------ + ValueError + If the crop boundaries extend outside the image dimensions. + Notes ----- The transform expects the following keys in the data dictionary: + - AC_center : np.ndarray Coordinates of anterior commissure - PC_center : np.ndarray @@ -46,10 +52,6 @@ class CropAroundACPCFixedSize(RandomizableTransform, MapTransform): - image : np.ndarray Input image to crop - Raises - ------ - ValueError - If the crop boundaries extend outside the image dimensions """ def __init__( diff --git a/CorpusCallosum/transforms/segmentation_transforms.py b/CorpusCallosum/transforms/segmentation_transforms.py index 0a11d9b7..2fcf41da 100644 --- a/CorpusCallosum/transforms/segmentation_transforms.py +++ b/CorpusCallosum/transforms/segmentation_transforms.py @@ -25,7 +25,7 @@ class CropAroundACPC(RandomizableTransform, MapTransform): Parameters ---------- keys : list[str] - Keys of the data dictionary to apply the transform to + Keys of the data dictionary to apply the transform to. allow_missing_keys : bool, default=False Whether to allow missing keys in the data dictionary. padding_mm : float, default=10.0 @@ -36,12 +36,14 @@ class CropAroundACPC(RandomizableTransform, MapTransform): Notes ----- The transform expects the following keys in the data dictionary: + - AC_center : np.ndarray Coordinates of anterior commissure - PC_center : np.ndarray Coordinates of posterior commissure - res : float Voxel resolution in mm + """ def __init__(self, keys: list[str], allow_missing_keys: bool = False, @@ -57,12 +59,12 @@ def __call__(self, data: dict) -> dict: Parameters ---------- data : dict - Dictionary containing the data to transform + Dictionary containing the data to transform. Returns ------- dict - Transformed data dictionary + Transformed data dictionary. """ d = dict(data) @@ -110,17 +112,18 @@ class CropAroundACPCtrack(CropAroundACPC): Parameters ---------- keys : list[str] - Keys of the data dictionary to apply the transform to + Keys of the data dictionary to apply the transform to. allow_missing_keys : bool, optional - Whether to allow missing keys in the data dictionary, by default False + Whether to allow missing keys in the data dictionary, by default False. padding_mm : float, optional - Padding around AC-PC region in millimeters, by default 10 + Padding around AC-PC region in millimeters, by default 10. random_translate : float, optional - Maximum random translation in voxels, by default 0 + Maximum random translation in voxels, by default 0. 
Notes ----- The transform expects the following keys in the data dictionary: + - AC_center : np.ndarray Coordinates of anterior commissure - PC_center : np.ndarray @@ -129,6 +132,7 @@ class CropAroundACPCtrack(CropAroundACPC): Original coordinates of anterior commissure - PC_center_original : np.ndarray Original coordinates of posterior commissure + """ def __call__(self, data: dict) -> dict: @@ -137,12 +141,12 @@ def __call__(self, data: dict) -> dict: Parameters ---------- data : dict - Dictionary containing the data to transform + Dictionary containing the data to transform. Returns ------- dict - Transformed data dictionary with updated AC and PC coordinates + Transformed data dictionary with updated AC and PC coordinates. """ diff --git a/CorpusCallosum/visualization/visualization.py b/CorpusCallosum/visualization/visualization.py index 398fde3d..db2feea2 100644 --- a/CorpusCallosum/visualization/visualization.py +++ b/CorpusCallosum/visualization/visualization.py @@ -30,13 +30,13 @@ def plot_standardized_space( Parameters ---------- ax_row : list[plt.Axes] - Row of axes to plot on (should be length 3) + Row of axes to plot on (should be length 3). vol : np.ndarray - Volume data to visualize + Volume data to visualize. ac_coords : np.ndarray - AC coordinates in standardized space + AC coordinates in standardized space. pc_coords : np.ndarray - PC coordinates in standardized space + PC coordinates in standardized space. Notes ----- @@ -73,25 +73,25 @@ def visualize_coordinate_spaces( Parameters ---------- orig : nibabel.Nifti1Image - Original image volume + Original image volume. upright : np.ndarray - Volume in fsaverage space + Volume in fsaverage space. standardized : np.ndarray - Volume in standardized space + Volume in standardized space. ac_coords_orig : np.ndarray - AC coordinates in original space + AC coordinates in original space. pc_coords_orig : np.ndarray - PC coordinates in original space + PC coordinates in original space. ac_coords_3d : np.ndarray - AC coordinates in fsaverage space + AC coordinates in fsaverage space. pc_coords_3d : np.ndarray - PC coordinates in fsaverage space + PC coordinates in fsaverage space. ac_coords_standardized : np.ndarray - AC coordinates in standardized space + AC coordinates in standardized space. pc_coords_standardized : np.ndarray - PC coordinates in standardized space + PC coordinates in standardized space. output_plot_path : str or Path - Directory to save visualization + Directory to save visualization. Notes ----- @@ -105,7 +105,7 @@ def visualize_coordinate_spaces( ax = ax.T # Original space - using plot_standardized_space - plot_standardized_space(ax[0], orig.get_fdata(), ac_coords_orig, pc_coords_orig) + plot_standardized_space(ax[0], np.asarray(orig.dataobj), ac_coords_orig, pc_coords_orig) ax[0, 0].set_title("Orig") # Fsaverage space @@ -141,9 +141,9 @@ def plot_contours( Parameters ---------- transformed : np.ndarray - Transformed image data + Transformed image data. split_contours : list[np.ndarray], optional - List of contour arrays for each subdivision (ignore contours on None) + List of contour arrays for each subdivision (ignore contours on None). midline_equidistant : np.ndarray, optional Midline points at equidistant spacing (ignore midline on None). levelpaths : list[np.ndarray], optional @@ -155,7 +155,7 @@ def plot_contours( pc_coords : np.ndarray, optional PC coordinates for visualization (ignore PC on None). vox_size : float, optional - Voxel size for scaling + Voxel size for scaling. 
title : str, default="" Title for the plot. @@ -221,9 +221,9 @@ def plot_midplane(grid_orig: np.ndarray, orig: np.ndarray) -> None: Parameters ---------- grid_orig : np.ndarray - Grid points in original space, shape (3, N) + Grid points in original space, shape (3, N). orig : np.ndarray - Original image for dimension reference + Original image for dimension reference. Notes ----- diff --git a/README.md b/README.md index eda2acd2..8ff862d8 100644 --- a/README.md +++ b/README.md @@ -24,16 +24,20 @@ Modules (all run by default): - the core, outputs anatomical segmentation and cortical parcellation and statistics of 95 classes, mimics FreeSurfer’s DKTatlas. - requires a T1w image ([notes on input images](#requirements-to-input-images)), supports high-res (up to 0.7mm, experimental beyond that). - performs bias-field correction and calculates volume statistics corrected for partial volume effects (skipped if `--no_biasfield` is passed). -2. `cereb:` [CerebNet](CerebNet/README.md) for cerebellum sub-segmentation (deactivate with `--no_cereb`) +2. `cc`: [CorpusCallosum](CorpusCallosum/README.md) for corpus callosum segmentation and shape analysis (deactivate with `--no_cc`) + - requires `asegdkt_segfile` (segmentation) and conformed mri (orig.mgz), outputs CC segmentation, thickness, and shape metrics. + - standardizes brain orientation based on AC/PC landmarks (orient_volume.lta). +3. `cereb:` [CerebNet](CerebNet/README.md) for cerebellum sub-segmentation (deactivate with `--no_cereb`) - requires `asegdkt_segfile`, outputs cerebellar sub-segmentation with detailed WM/GM delineation. - requires a T1w image ([notes on input images](#requirements-to-input-images)), which will be resampled to 1mm isotropic images (no native high-res support). - calculates volume statistics corrected for partial volume effects (skipped if `--no_biasfield` is passed). -3. `hypothal`: [HypVINN](HypVINN/README.md) for hypothalamus subsegmentation (deactivate with `--no_hypothal`) +4. `hypothal`: [HypVINN](HypVINN/README.md) for hypothalamus subsegmentation (deactivate with `--no_hypothal`) - outputs a hypothalamic subsegmentation including 3rd ventricle, c. mammilare, fornix and optic tracts. - a T1w image is highly recommended ([notes on input images](#requirements-to-input-images)), supports high-res (up to 0.7mm, but experimental beyond that). - allows the additional passing of a T2w image with `--t2 `, which will be registered to the T1w image (see `--reg_mode` option). - calculates volume statistics corrected for partial volume effects based on the T1w image (skipped if `--no_bias_field` is passed). + ### Surface reconstruction - approximately 60-90 minutes, `--surf_only` runs only [the surface part](recon_surf/README.md). - supports high-resolution images (up to 0.7mm, experimental beyond that). @@ -125,6 +129,8 @@ All the examples can be found here: [FASTSURFER_EXAMPLES](doc/overview/EXAMPLES. Modules output can be found here: [FastSurfer_Output_Files](doc/overview/OUTPUT_FILES.md) - [Segmentation module](doc/overview/OUTPUT_FILES.md#segmentation-module) - [Cerebnet module](doc/overview/OUTPUT_FILES.md#cerebnet-module) +- [HypVINN module](doc/overview/OUTPUT_FILES.md#hypvinn-module) +- [Corpus Callosum module](doc/overview/OUTPUT_FILES.md#corpus-callosum-module) - [Surface module](doc/overview/OUTPUT_FILES.md#surface-module) @@ -146,7 +152,7 @@ The default device is the GPU. 
The view-aggregation device can be switched to CP ## Expert usage Individual modules and the surface pipeline can be run independently of the full pipeline script documented in this documentation. -This is documented in READMEs in subfolders, for example: [whole brain segmentation only with FastSurferVINN](FastSurferCNN/README.md), [cerebellum sub-segmentation](CerebNet/README.md), [hypothalamic sub-segmentation](HypVINN/README.md) and [surface pipeline only (recon-surf)](recon_surf/README.md). +This is documented in READMEs in subfolders, for example: [whole brain segmentation only with FastSurferVINN](FastSurferCNN/README.md), [cerebellum sub-segmentation](CerebNet/README.md), [hypothalamic sub-segmentation](HypVINN/README.md), [corpus callosum analysis](CorpusCallosum/README.md) and [surface pipeline only (recon-surf)](recon_surf/README.md). Specifically, the segmentation modules feature options for optimized parallelization of batch processing. diff --git a/doc/api/CorpusCallosum_utils.rst b/doc/api/CorpusCallosum_utils.rst index bf06f78b..f0490a10 100644 --- a/doc/api/CorpusCallosum_utils.rst +++ b/doc/api/CorpusCallosum_utils.rst @@ -7,4 +7,3 @@ CorpusCallosum.utils :toctree: generated/ checkpoint - utils diff --git a/doc/api/index.rst b/doc/api/index.rst index fb5492a8..6a099907 100644 --- a/doc/api/index.rst +++ b/doc/api/index.rst @@ -18,8 +18,13 @@ FastSurfer API CerebNet.utils.rst CorpusCallosum.rst CorpusCallosum_data.rst + CorpusCallosum_localization.rst + CorpusCallosum_registration.rst + CorpusCallosum_segmentation.rst CorpusCallosum_shape.rst + CorpusCallosum_transforms.rst CorpusCallosum_utils.rst + CorpusCallosum_visualization.rst HypVINN.rst HypVINN.dataloader.rst HypVINN.models.rst diff --git a/doc/overview/FLAGS.md b/doc/overview/FLAGS.md index 3f06d74d..735136fb 100644 --- a/doc/overview/FLAGS.md +++ b/doc/overview/FLAGS.md @@ -6,7 +6,7 @@ The `*fastsurfer-flags*` will usually at least include the subject directory (`- ```bash ... --sd /output --sid test_subject --t1 /data/test_subject_t1.nii.gz --3T ``` -Additionally, you can use `--seg_only` or `--surf_only` to only run a part of the pipeline or `--no_biasfield`, `--no_cereb` and `--no_asegdkt` to switch off individual segmentation modules. +Additionally, you can use `--seg_only` or `--surf_only` to only run a part of the pipeline or `--no_biasfield`, `--no_cereb`, `--no_hypothal`, `--no_cc`, and `--no_asegdkt` to switch off individual segmentation modules. Here, we have also added the `--3T` flag, which tells FastSurfer to register against the 3T atlas which is only relevant for the ICV estimation (eTIV). In the following, we give an overview of the most important options. You can view a [full list of options](FLAGS.md#full-list-of-flags) with @@ -30,6 +30,8 @@ In the following, we give an overview of the most important options. You can vie * `--device`: Select device for neural network segmentation (_auto_, _cpu_, _cuda_, _cuda:_, _mps_), where cuda means Nvidia GPU, you can select which one e.g. "cuda:1". Default: "auto", check GPU and then CPU. "mps" is for native MAC installs to use the Apple silicon (M-chip) GPU. * `--asegdkt_segfile`: Name of the segmentation file, which includes the aparc+DKTatlas-aseg segmentations. Requires an ABSOLUTE Path! Default location: \$SUBJECTS_DIR/\$sid/mri/aparc.DKTatlas+aseg.deep.mgz * `--no_cereb`: Switch off the cerebellum sub-segmentation. +* `--no_hypothal`: Skip the hypothalamus segmentation. +* `--no_cc`: Skip the segmentation and analysis of the corpus callosum. 
* `--cereb_segfile`: Name of the cerebellum segmentation file. If not provided, this intermediate DL-based segmentation will not be stored, but only the merged segmentation will be stored (see --main_segfile ). Requires an ABSOLUTE Path! Default location: \$SUBJECTS_DIR/\$sid/mri/cerebellum.CerebNet.nii.gz * `--no_biasfield`: Deactivate the biasfield correction and calculation of partial volume-corrected statistics in the segmentation modules. * `--native_image` or `--keepgeom`: **Only supported for `--seg_only`**, segment in native image space (keep orientation, image size and voxel size of the input image), this also includes experimental support for anisotropic images (no extreme anisotropy). diff --git a/doc/overview/OUTPUT_FILES.md b/doc/overview/OUTPUT_FILES.md index 87416ed3..8d7ab7b2 100644 --- a/doc/overview/OUTPUT_FILES.md +++ b/doc/overview/OUTPUT_FILES.md @@ -15,6 +15,31 @@ The segmentation module outputs the files shown in the table below. The two prim | scripts | deep-seg.log | asegdkt | logfile | | stats | aseg+DKT.stats | asegdkt | table of cortical and subcortical segmentation statistics | + +## Corpus Callosum module + +The Corpus Callosum module outputs the files in the table shown below. It creates detailed segmentations and shape analysis of the corpus callosum. + +| directory | filename | module | description | +|:----------------|--------------------------------|--------|--------------------------------------------------------------------------------------------------------------| +| mri | callosum_seg_upright.mgz | cc | corpus callosum segmentation in upright space | +| mri | callosum_seg_aseg_space.mgz | cc | corpus callosum segmentation in conformed image orientation | +| mri | callosum_seg_soft.mgz | cc | corpus callosum soft labels | +| mri | fornix_seg_soft.mgz | cc | fornix soft labels | +| mri | background_seg_soft.mgz | cc | background soft labels | +| mri/transforms | cc_up.lta | cc | transform from original to upright space | +| mri/transforms | orient_volume.lta | cc | transform to standardized space | +| stats | callosum.CC.midslice.json | cc | measurements from the middle sagittal slice (landmarks, area, thickness, etc.) | +| stats | callosum.CC.all_slices.json | cc | comprehensive per-slice analysis (only when using `--slice_selection all`) | +| qc_snapshots | callosum.png | cc | debug visualization of contours and thickness | +| qc_snapshots | callosum_thickness.png | cc | 3D thickness visualization (with `--slice_selection all`) | +| qc_snapshots | corpus_callosum.html | cc | interactive 3D mesh visualization (with `--slice_selection all`) | +| surf | callosum.surf | cc | FreeSurfer surface format (with `--slice_selection all`) | +| surf | callosum.thickness.w | cc | FreeSurfer overlay file containing thickness values (with `--slice_selection all`) | +| surf | callosum_mesh.vtk | cc | VTK format mesh file for 3D visualization (with `--slice_selection all`) | + + + ## Cerebnet module The cerebellum module outputs the files in the table shown below. Unless switched off by the `--no_cereb` argument, this module is automatically run whenever the segmentation module is run. It adds two files, an image with the sub-segmentation of the cerebellum and a text file with summary statistics. 
@@ -73,4 +98,4 @@ The primary output files are pial, white, and inflated surface files, the thickn | stats | lh.aparc.DKTatlas.mapped.stats, rh.aparc.DKTatlas.mapped.stats | surface | table of cortical parcellation statistics, mapped from ASEGDKT segmentation to the surface | | stats | lh.curv.stats, rh.curv.stats | surface | table of curvature statistics | | stats | wmparc.DKTatlas.mapped.stats | surface | table of white matter segmentation statistics | -| scripts | recon-all.log | surface | logfile | \ No newline at end of file +| scripts | recon-all.log | surface | logfile | diff --git a/doc/overview/index.rst b/doc/overview/index.rst index e41f6593..2fca45ff 100644 --- a/doc/overview/index.rst +++ b/doc/overview/index.rst @@ -10,6 +10,7 @@ User Guide EXAMPLES.md FLAGS.md OUTPUT_FILES.md + modules/index docker SINGULARITY.md MACOS.md diff --git a/doc/scripts/advanced.rst b/doc/scripts/advanced.rst index 82551a7c..d18d755d 100644 --- a/doc/scripts/advanced.rst +++ b/doc/scripts/advanced.rst @@ -7,6 +7,8 @@ Advanced scripts fastsurfercnn cerebnet hypvinn + fastsurfer_cc + cc_visualization recon_surf segstats long_compat_segmentHA From 4d2f6df635636ea8d0794e361ed82f85e7e9291b Mon Sep 17 00:00:00 2001 From: ClePol Date: Wed, 26 Nov 2025 17:55:58 +0100 Subject: [PATCH 33/68] updated logging in paint_cc_into_pred and added missing docfiles --- CorpusCallosum/paint_cc_into_pred.py | 12 ++- CorpusCallosum/registration/__init__.py | 0 doc/api/CorpusCallosum_localization.rst | 9 ++ doc/api/CorpusCallosum_registration.rst | 9 ++ doc/api/CorpusCallosum_segmentation.rst | 10 ++ doc/api/CorpusCallosum_transforms.rst | 10 ++ doc/api/CorpusCallosum_visualization.rst | 9 ++ doc/overview/modules/CC.md | 127 +++++++++++++++++++++++ doc/overview/modules/index.rst | 9 ++ doc/scripts/cc_visualization.rst | 50 +++++++++ doc/scripts/fastsurfer_cc.rst | 11 ++ 11 files changed, 251 insertions(+), 5 deletions(-) create mode 100644 CorpusCallosum/registration/__init__.py create mode 100644 doc/api/CorpusCallosum_localization.rst create mode 100644 doc/api/CorpusCallosum_registration.rst create mode 100644 doc/api/CorpusCallosum_segmentation.rst create mode 100644 doc/api/CorpusCallosum_transforms.rst create mode 100644 doc/api/CorpusCallosum_visualization.rst create mode 100644 doc/overview/modules/CC.md create mode 100644 doc/overview/modules/index.rst create mode 100644 doc/scripts/cc_visualization.rst create mode 100644 doc/scripts/fastsurfer_cc.rst diff --git a/CorpusCallosum/paint_cc_into_pred.py b/CorpusCallosum/paint_cc_into_pred.py index f344553b..0af929e8 100644 --- a/CorpusCallosum/paint_cc_into_pred.py +++ b/CorpusCallosum/paint_cc_into_pred.py @@ -295,7 +295,7 @@ def correct_wm_ventricles( voxel_size = tuple(aseg_image.header.get_zooms()) pred_corrected = correct_wm_ventricles(aseg_data, fornix_mask, voxel_size) - print(f"Writing segmentation with corpus callosum to: {options.output}") + logger.info(f"Writing segmentation with corpus callosum to: {options.output}") pred_with_cc_fin = nib.MGHImage(pred_corrected, aseg_image.affine, aseg_image.header) io_fut = thread_executor().submit(pred_with_cc_fin.to_filename, options.output) @@ -314,10 +314,10 @@ def correct_wm_ventricles( initial_cc = np.sum(mask_in_array(aseg_data, SUBSEGMENT_LABELS)) initial_fornix = np.sum(aseg_data == FORNIX_LABEL) initial_wm = np.sum((aseg_data == 2) | (aseg_data == 41)) - print(f"Initial segmentation: CC={initial_cc}, Fornix={initial_fornix}, WM={initial_wm}") + logger.info(f"Initial segmentation: CC={initial_cc}, 
Fornix={initial_fornix}, WM={initial_wm}")
 
     after_paint_cc = np.sum(mask_in_array(pred_with_cc, SUBSEGMENT_LABELS))
-    print(f"After painting CC: {after_paint_cc} CC voxels added")
+    logger.info(f"After painting CC: {after_paint_cc} CC voxels added")
 
     # Count final labels
     final_cc = np.sum(mask_in_array(pred_corrected, SUBSEGMENT_LABELS))
@@ -325,8 +325,10 @@ def correct_wm_ventricles(
     final_wm = np.sum((pred_corrected == 2) | (pred_corrected == 41))
     final_ventricles = np.sum((pred_corrected == 4) | (pred_corrected == 43))
 
-    logger.info(f"Final segmentation: CC={final_cc}, Fornix={final_fornix}, WM={final_wm}, Ventricles={final_ventricles}")
-    logger.info(f"Changes: CC +{final_cc-initial_cc}, Fornix {final_fornix-initial_fornix}, WM {final_wm-initial_wm}")
+    logger.info(f"Final segmentation: CC={final_cc}, Fornix={final_fornix}, "
+                f"WM={final_wm}, Ventricles={final_ventricles}")
+    logger.info(f"Changes: CC +{final_cc-initial_cc}, Fornix {final_fornix-initial_fornix}, "
+                f"WM {final_wm-initial_wm}")
 
     if rta_fut is not None:
         _ = rta_fut.result()
diff --git a/CorpusCallosum/registration/__init__.py b/CorpusCallosum/registration/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/doc/api/CorpusCallosum_localization.rst b/doc/api/CorpusCallosum_localization.rst
new file mode 100644
index 00000000..142ad093
--- /dev/null
+++ b/doc/api/CorpusCallosum_localization.rst
@@ -0,0 +1,9 @@
+CorpusCallosum.localization
+=============================
+
+.. currentmodule:: CorpusCallosum.localization
+
+.. autosummary::
+   :toctree: generated/
+
+   localization_inference
diff --git a/doc/api/CorpusCallosum_registration.rst b/doc/api/CorpusCallosum_registration.rst
new file mode 100644
index 00000000..add8daa1
--- /dev/null
+++ b/doc/api/CorpusCallosum_registration.rst
@@ -0,0 +1,9 @@
+CorpusCallosum.registration
+============================
+
+.. currentmodule:: CorpusCallosum.registration
+
+.. autosummary::
+   :toctree: generated/
+
+   mapping_helpers
diff --git a/doc/api/CorpusCallosum_segmentation.rst b/doc/api/CorpusCallosum_segmentation.rst
new file mode 100644
index 00000000..291c14e3
--- /dev/null
+++ b/doc/api/CorpusCallosum_segmentation.rst
@@ -0,0 +1,10 @@
+CorpusCallosum.segmentation
+============================
+
+.. currentmodule:: CorpusCallosum.segmentation
+
+.. autosummary::
+   :toctree: generated/
+
+   segmentation_inference
+   segmentation_postprocessing
diff --git a/doc/api/CorpusCallosum_transforms.rst b/doc/api/CorpusCallosum_transforms.rst
new file mode 100644
index 00000000..44b63315
--- /dev/null
+++ b/doc/api/CorpusCallosum_transforms.rst
@@ -0,0 +1,10 @@
+CorpusCallosum.transforms
+===========================
+
+.. currentmodule:: CorpusCallosum.transforms
+
+.. autosummary::
+   :toctree: generated/
+
+   localization_transforms
+   segmentation_transforms
diff --git a/doc/api/CorpusCallosum_visualization.rst b/doc/api/CorpusCallosum_visualization.rst
new file mode 100644
index 00000000..f6801dec
--- /dev/null
+++ b/doc/api/CorpusCallosum_visualization.rst
@@ -0,0 +1,9 @@
+CorpusCallosum.visualization
+==============================
+
+.. currentmodule:: CorpusCallosum.visualization
+
+.. autosummary::
+   :toctree: generated/
+
+   visualization
diff --git a/doc/overview/modules/CC.md b/doc/overview/modules/CC.md
new file mode 100644
index 00000000..5795136c
--- /dev/null
+++ b/doc/overview/modules/CC.md
@@ -0,0 +1,127 @@
+# Corpus Callosum Pipeline
+
+A deep learning-based pipeline for automated segmentation and shape analysis of the corpus callosum in brain MRI scans.
+Also segments the fornix, localizes the anterior and posterior commissures (AC and PC), and standardizes the orientation of the brain.
+
+## Overview
+
+This pipeline combines localization and segmentation deep learning models to:
+1. Detect AC (Anterior Commissure) and PC (Posterior Commissure) points
+2. Extract and align midplane slices
+3. Segment the corpus callosum
+4. Perform advanced morphometry of the corpus callosum, including subdivision, thickness analysis, and various shape metrics
+5. Generate visualizations and measurements
+
+## Analysis Modes
+
+The pipeline supports different analysis modes that determine the type of template data generated.
+
+### 3D Analysis
+
+When running the main pipeline with `--slice_selection all` and `--save_template`, a complete 3D template is generated:
+
+```bash
+# Generate 3D template data
+python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \
+    --slice_selection all \
+    --save_template /data/templates/sub001
+```
+
+This creates:
+- `contours.txt`: Multi-slice contour data for 3D reconstruction
+- `thickness_values.txt`: Thickness measurements across all slices
+- `measurement_points.txt`: 3D vertex indices for thickness measurements
+
+**Benefits:**
+- Enables volumetric thickness analysis
+- Supports advanced 3D visualizations with proper surface topology
+- Creates FreeSurfer-compatible overlay files for integration with other tools
+
+For visualization instructions and outputs, see the [cc_visualization.py documentation](../../scripts/cc_visualization.rst).
+
+### 2D Analysis
+
+When using `--slice_selection middle` or a specific slice number with `--save_template`:
+
+```bash
+# Generate 2D template data (middle slice)
+python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \
+    --slice_selection middle \
+    --save_template /data/templates/sub001
+```
+
+**Benefits:**
+- Faster processing for single-slice analysis
+- 2D visualization is most suitable for displaying downstream statistics
+- Compatibility with classical corpus callosum studies
+
+For 2D visualization instructions and outputs, see the [cc_visualization.py documentation](../../scripts/cc_visualization.rst).
+
+### Choosing Analysis Mode
+
+**Use 3D Analysis (`--slice_selection all`) when:**
+- You need complete volumetric analysis
+- Surface-based visualization is required
+- Integration with FreeSurfer workflows is needed
+- Comprehensive thickness mapping across the entire corpus callosum is desired
+
+**Use 2D Analysis (`--slice_selection middle` or specific slice) when:**
+- Traditional single-slice morphometry is sufficient
+- Faster processing is preferred
+- Focus is on mid-sagittal cross-sectional measurements
+- Compatibility with classical corpus callosum studies is needed
+
+**Note:** The default behavior is `--slice_selection all` for comprehensive 3D analysis. Use `--slice_selection middle` to process only the middle slice for faster, traditional 2D analysis.
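+
+### Loading Template Data
+
+As a minimal sketch, the saved template files can be read back with NumPy. The exact
+on-disk layout is an assumption here (plain whitespace-delimited text readable by
+`np.loadtxt`); adjust the parsing if the actual format differs:
+
+```python
+import numpy as np
+
+# thickness_values.txt: thickness measurements written by --save_template
+thickness = np.loadtxt("/data/templates/sub001/thickness_values.txt")
+
+# measurement_points.txt: vertex indices of the thickness measurement points
+points = np.loadtxt("/data/templates/sub001/measurement_points.txt")
+
+print(thickness.shape, points.shape)
+```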
+
+## JSON Output Structure
+
+The pipeline generates two main JSON files with detailed measurements and analysis results:
+
+### `stats/callosum.CC.midslice.json` (Middle Slice Analysis)
+
+This file contains measurements from the middle sagittal slice and includes:
+
+**Shape Measurements (single values):**
+- `total_area`: Total corpus callosum area (mm²)
+- `total_perimeter`: Total perimeter length (mm)
+- `circularity`: Shape circularity measure (4π × area / perimeter²)
+- `cc_index`: Corpus callosum shape index (length/width ratio)
+- `midline_length`: Length along the corpus callosum midline (mm)
+- `curvature`: Average curvature of the midline (degrees), measured by the angle between its sub-segments
+
+**Subdivisions:**
+- `areas`: Areas of the CC using an improved Hofer-Frahm subdivision method (mm²). This gives more consistent sub-segments while preserving the original ratios.
+
+**Thickness Analysis:**
+- `thickness`: Average corpus callosum thickness (mm)
+- `thickness_profile`: Thickness profile (mm) of the corpus callosum slice (100 thickness values by default, listed from anterior to posterior CC ends)
+
+**Volume Measurements (when multiple slices processed):**
+- `cc_5mm_volume`: Total CC volume within a 5 mm slab using voxel counting (mm³)
+- `cc_5mm_volume_pv_corrected`: Volume with partial volume correction using CC contours (mm³)
+
+**Anatomical Landmarks:**
+- `ac_center`: Anterior commissure coordinates in original image space
+- `pc_center`: Posterior commissure coordinates in original image space
+- `ac_center_oriented_volume`: AC coordinates in standardized space (orient_volume.lta)
+- `pc_center_oriented_volume`: PC coordinates in standardized space (orient_volume.lta)
+- `ac_center_upright`: AC coordinates in upright space (cc_up.lta)
+- `pc_center_upright`: PC coordinates in upright space (cc_up.lta)
+
+### `stats/callosum.CC.all_slices.json` (Multi-Slice Analysis)
+
+This file contains comprehensive per-slice analysis when using `--slice_selection all`:
+
+**Global Parameters:**
+- `slices_in_segmentation`: Total number of slices in the segmentation volume
+- `voxel_size`: Voxel dimensions [x, y, z] in mm
+- `subdivision_method`: Method used for anatomical subdivision
+- `num_thickness_points`: Number of points used for thickness estimation
+- `subdivision_ratios`: Subdivision fractions used for regional analysis
+- `contour_smoothing`: Gaussian sigma used for contour smoothing
+- `slice_selection`: Slice selection mode used
+
+**Per-Slice Data (`slices` array):**
+
+Each slice entry contains the shape measurements, thickness analysis, and subdivisions described above.
diff --git a/doc/overview/modules/index.rst b/doc/overview/modules/index.rst
new file mode 100644
index 00000000..17b1cc45
--- /dev/null
+++ b/doc/overview/modules/index.rst
@@ -0,0 +1,9 @@
+Modules
+=======
+
+FastSurfer includes several specialized deep learning modules that can be run independently or as part of the main pipeline. These modules provide detailed sub-segmentations and analyses for specific brain regions.
+
+.. toctree::
+   :maxdepth: 2
+
+   CC
diff --git a/doc/scripts/cc_visualization.rst b/doc/scripts/cc_visualization.rst
new file mode 100644
index 00000000..068a5a2c
--- /dev/null
+++ b/doc/scripts/cc_visualization.rst
@@ -0,0 +1,50 @@
+CorpusCallosum: cc_visualization.py
+===================================
+
+.. 
argparse:: + :module: CorpusCallosum.cc_visualization + :func: make_parser + :prog: cc_visualization.py + +Usage Examples +-------------- + +3D Visualization +~~~~~~~~~~~~~~~~ + +To visualize a 3D template generated by ``fastsurfer_cc.py`` (using ``--slice_selection all --save_template ...``): + +.. code-block:: bash + + python3 cc_visualization.py \ + --contours /data/templates/sub001/contours.txt \ + --thickness /data/templates/sub001/thickness_values.txt \ + --measurement_points /data/templates/sub001/measurement_points.txt \ + --output_dir /data/visualizations/sub001 + +2D Visualization +~~~~~~~~~~~~~~~~ + +To visualize a 2D template (using ``--slice_selection middle --save_template ...``): + +.. code-block:: bash + + python3 cc_visualization.py \ + --thickness /data/templates/sub001/thickness_values.txt \ + --measurement_points /data/templates/sub001/measurement_points.txt \ + --output_dir /data/visualizations/sub001 \ + --twoD + +Outputs +------- + +3D Mode Outputs (default): + - ``cc_mesh.vtk``: VTK format mesh file for 3D visualization + - ``cc_mesh.fssurf``: FreeSurfer surface format + - ``cc_mesh_overlay.curv``: FreeSurfer overlay file with thickness values + - ``cc_mesh.html``: Interactive 3D mesh visualization + - ``cc_mesh_snap.png``: Snapshot image of the 3D mesh + - ``midslice_2d.png``: 2D visualization of the middle slice + +2D Mode Outputs (when ``--twoD`` is specified): + - ``cc_thickness_2d.png``: 2D contour visualization with thickness colormap diff --git a/doc/scripts/fastsurfer_cc.rst b/doc/scripts/fastsurfer_cc.rst new file mode 100644 index 00000000..3f73ebba --- /dev/null +++ b/doc/scripts/fastsurfer_cc.rst @@ -0,0 +1,11 @@ +CorpusCallosum: fastsurfer_cc.py +================================ + +.. argparse:: + :module: CorpusCallosum.fastsurfer_cc + :func: make_parser + :prog: fastsurfer_cc.py + +.. include:: ../overview/modules/CC.md + :parser: myst_parser.sphinx_ + From 6fb08c59f2d57a27f900aa82955a5046b49584fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Wed, 26 Nov 2025 17:18:25 +0100 Subject: [PATCH 34/68] Improve the left_right masking. improve type annotations of brainvolstats. Fix CC measure computation Fix conform check in paint_cc_in_pred --- CorpusCallosum/paint_cc_into_pred.py | 39 +++++++++++++++++----------- run_fastsurfer.sh | 28 ++++++++++---------- 2 files changed, 38 insertions(+), 29 deletions(-) diff --git a/CorpusCallosum/paint_cc_into_pred.py b/CorpusCallosum/paint_cc_into_pred.py index 0af929e8..f39a1c29 100644 --- a/CorpusCallosum/paint_cc_into_pred.py +++ b/CorpusCallosum/paint_cc_into_pred.py @@ -17,6 +17,7 @@ import argparse import sys +from functools import partial from pathlib import Path from typing import TypeVar, cast @@ -160,16 +161,15 @@ def correct_wm_ventricles( # Combine all WM labels all_wm_mask = (aseg_cc == 2) | (aseg_cc == 41) - # 1. 
Fill space between CC and ventricles # Only fill small gaps (up to 3 voxels) between CC and ventricle boundaries #for ventricle_label, ventricle_mask in [(4, left_ventricle_mask), (43, right_ventricle_mask)]: # Process each slice independently for x in range(corrected_pred.shape[0]): - cc_slice = cc_mask + cc_slice = cc_mask[x] #vent_slice = ventricle_mask - all_wm_slice = all_wm_mask + all_wm_slice = all_wm_mask[x] if all_wm_slice.any() and cc_slice.any(): @@ -185,7 +185,6 @@ def correct_wm_ventricles( if np.any(component_mask & cc_dilated): corrected_pred[x][component_mask] = 0 # Set to background - if fornix_mask[x].any(): fornix_slice = fornix_mask[x] # count WM labels overlapping with fornix @@ -193,7 +192,6 @@ def correct_wm_ventricles( right_wm_overlap = np.sum(fornix_slice & (aseg_cc == 41)) corrected_pred[x][fornix_slice] = 2 + (left_wm_overlap > right_wm_overlap) * 39 # Left WM / Right WM - vent_slice = all_ventricle_mask potential_fill = np.asarray([False]) if cc_slice.any() and vent_slice.any(): @@ -261,7 +259,6 @@ def correct_wm_ventricles( corrected_pred[x, y, :][gap_mask & (corrected_pred[x, y, :] == 0)] = vent_label - return corrected_pred @@ -269,21 +266,33 @@ def correct_wm_ventricles( # Command Line options are error checking done here options = argument_parse() + logging.setup_logging() + logger.info(f"Reading inputs: {options.input_cc} {options.input_pred}...") cc_seg_image = cast(nib.analyze.SpatialImage, nib.load(options.input_cc)) cc_seg_data = np.asanyarray(cc_seg_image.dataobj) aseg_image = cast(nib.analyze.SpatialImage, nib.load(options.input_pred)) aseg_data = np.asanyarray(aseg_image.dataobj) - cc_conformed = is_conform(cc_seg_image, vox_size=None, img_size=None, verbose=False) - pred_conformed = is_conform(aseg_image, vox_size=None, img_size=None, dtype=np.integer, verbose=False) - if not cc_conformed: - sys.exit("Error: CC input image is not conformed (LIA orientation, uint8 dtype). \ - Please conform the image using the conform.py script.") - if not pred_conformed: - sys.exit("Error: Prediction input image is not conformed (LIA orientation, integer dtype). \ - Please conform the image using the conform.py script.") - if not np.allclose(cc_conformed, pred_conformed): + def _is_conform(img, dtype, verbose): + return is_conform(img, vox_size=None, img_size=None, verbose=verbose, dtype=dtype) + + conform_args = (cc_seg_image, aseg_image), (np.uint8, np.integer) + conform_checks = list(thread_executor().map(partial(_is_conform, verbose=False), *conform_args)) + + if not all(conform_checks): + names = [] + dtypes = [] + for conform_check, img, dtype, name in zip(conform_checks, *conform_args, ("CC", "Prediction"), strict=True): + if not conform_check: + _is_conform(img, dtype, verbose=True) + names.append(name) + dtypes.append(dtype.name if hasattr(dtype, "name") else str(dtype)) + sys.exit( + f"Error: {' and '.join(names)} input image is not conformed (LIA orientation, {'/'.join(dtypes)} dtype). " + "Please conform the image(s) using the conform.py script." 
+    )
+    if not np.allclose(cc_seg_image.affine, aseg_image.affine):
         sys.exit("Error: The affine matrices of the aseg and the corpus callosum images are not the same.")
 
     # Paint CC into prediction
diff --git a/run_fastsurfer.sh b/run_fastsurfer.sh
index 367c2a31..12df8d07 100755
--- a/run_fastsurfer.sh
+++ b/run_fastsurfer.sh
@@ -1111,7 +1111,7 @@ then
     # generate file names of for the analysis
     asegdkt_withcc_segfile="$(add_file_suffix "$asegdkt_segfile" "withCC")"
     asegdkt_withcc_vinn_statsfile="$(add_file_suffix "$asegdkt_vinn_statsfile" "withCC")"
-    aseg_auto_statsfile="$(add_file_suffix "$aseg_auto_statsfile" "withCC")"
+    aseg_auto_statsfile="$(basename "$aseg_vinn_statsfile")/aseg.auto.mgz"
     # note: callosum manedit currently only affects inpainting and not internal FastSurferCC processing (surfaces etc)
     callosum_seg_manedit="$(add_file_suffix "$callosum_seg" "manedit")"
     # generate callosum segmentation, mesh, shape and downstream measure files
@@ -1132,25 +1132,28 @@ then
 
   if [[ "$run_biasfield" == 1 ]]
   then
-    # TODO: decide how to measure the size of the white matter, maybe import measures from previous?
-    # TODO: decide whether to include the fornix PV-corrected volume
-    # PV list here and not asegdkt_segfile: 192 Fornix
     cmd=($python "${fastsurfercnndir}/segstats.py" --segfile "$asegdkt_withcc_segfile"
          --normfile "$norm_name" --lut "$fastsurfercnndir/config/FreeSurferColorLUT.txt"
          --sd "${sd}" --sid "${subject}"
         --ids 2 4 5 7 8 10 11 12 13 14 15 16 17 18 24 26 28 31 41 43 44 46 47 49 50 51 52 53
-          54 58 60 63 77 192 251 252 253 254 255
+          54 58 60 63 77 251 252 253 254 255
          1002 1003 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019
          1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1034 1035 2002 2003 2005
          2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022
          2023 2024 2025 2026 2027 2028 2029 2030 2031 2034 2035
         --threads "$threads_seg" --empty --excludeid 0
-        --segstatsfile "$aseg_withcc_vinn_statsfile"
+        --segstatsfile "$asegdkt_withcc_vinn_statsfile"
         measures
         # the following measures are unaffected by CC and do not need to be recomputed
-        --import Mask --file "$asegdkt_vinn_statsfile"
-        # recompute the measures based on "better" volumes:
-        --compute BrainSeg BrainSegNotVent SupraTentorial SupraTentorialNotVent
-          SubCortGray rhCerebralWhiteMatter lhCerebralWhiteMatter CerebralWhiteMatter
+        --import SubCortGray Mask
+    )
+    if [[ "$run_talairach_registration" == "true" ]]
+    then
+      cmd+=("EstimatedTotalIntraCranialVol" "BrainSegVol-to-eTIV" "MaskVol-to-eTIV")
+    fi
+    cmd+=(--file "$asegdkt_vinn_statsfile"
+      # recompute the measures affected by CC inpainting (only SubCortGray does not change)
+      --compute BrainSeg BrainSegNotVent SupraTentorial SupraTentorialNotVent
+        rhCerebralWhiteMatter lhCerebralWhiteMatter CerebralWhiteMatter
     )
     echo_quoted "${cmd[@]}"
     "${cmd[@]}"
@@ -1166,14 +1169,11 @@ then
   if [[ "$run_biasfield" == 1 ]]
   then
     {
-      # TODO: decide how to measure the size of the white matter
-      # TODO: decide whether to include the fornix PV-corrected volume
-      # PV list here and not asegdkt_segfile: 192 Fornix
      cmd=($python "${fastsurfercnndir}/segstats.py" --segfile "$aseg_auto_segfile"
           --normfile "$norm_name" --lut "$fastsurfercnndir/config/FreeSurferColorLUT.txt"
           --sd "${sd}" --sid "${subject}" --threads "$threads_seg" --empty --excludeid 0
           --ids 2 4 3 5 7 8 10 11 12 13 14 15 16 17 18 24 26 28 31 41 42 43 44 46 47 49 50 51 52 53 54 58 60 63 77
-          192 251 252 253 254 255
+          251 252 253 254 255
          --segstatsfile
"$aseg_auto_statsfile" measures --import "all" --file "$asegdkt_withcc_vinn_statsfile" ) From d2bee058def7fea2b633b705b35e73fbee9546fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Thu, 27 Nov 2025 18:16:36 +0100 Subject: [PATCH 35/68] Fix the CorpusCallosum documentation --- CorpusCallosum/README.md | 5 ++--- ...Callosum_data.rst => CorpusCallosum.data.rst} | 0 ...ation.rst => CorpusCallosum.localization.rst} | 0 ...ation.rst => CorpusCallosum.registration.rst} | 0 ...ation.rst => CorpusCallosum.segmentation.rst} | 0 ...llosum_shape.rst => CorpusCallosum.shape.rst} | 0 ...nsforms.rst => CorpusCallosum.transforms.rst} | 0 ...llosum_utils.rst => CorpusCallosum.utils.rst} | 0 ...tion.rst => CorpusCallosum.visualization.rst} | 0 doc/api/index.rst | 16 ++++++++-------- doc/scripts/fastsurfer_cc.rst | 14 ++++++++++++-- 11 files changed, 22 insertions(+), 13 deletions(-) rename doc/api/{CorpusCallosum_data.rst => CorpusCallosum.data.rst} (100%) rename doc/api/{CorpusCallosum_localization.rst => CorpusCallosum.localization.rst} (100%) rename doc/api/{CorpusCallosum_registration.rst => CorpusCallosum.registration.rst} (100%) rename doc/api/{CorpusCallosum_segmentation.rst => CorpusCallosum.segmentation.rst} (100%) rename doc/api/{CorpusCallosum_shape.rst => CorpusCallosum.shape.rst} (100%) rename doc/api/{CorpusCallosum_transforms.rst => CorpusCallosum.transforms.rst} (100%) rename doc/api/{CorpusCallosum_utils.rst => CorpusCallosum.utils.rst} (100%) rename doc/api/{CorpusCallosum_visualization.rst => CorpusCallosum.visualization.rst} (100%) diff --git a/CorpusCallosum/README.md b/CorpusCallosum/README.md index a417dd84..9e945e35 100644 --- a/CorpusCallosum/README.md +++ b/CorpusCallosum/README.md @@ -10,8 +10,7 @@ For detailed documentation, please refer to: ## Quickstart ```bash -python3 fastsurfer_cc.py --subject_dir /path/to/fastsurfer/output --verbose +python3 fastsurfer_cc.py --sd /path/to/fastsurfer/output --sid test-case --verbose ``` -Gives all standard outputs. Then corpus callosum morphometry can be found at `stats/callosum.CC.midslice.json`, including 100 thickness measurements and areas of sub-segments. -Visualization will be placed in `/path/to/fastsurfer/output/qc_snapshots`. +Gives all standard outputs. The corpus callosum morphometry can be found at `stats/callosum.CC.midslice.json` including 100 thickness measurements and the areas of sub-segments. 
diff --git a/doc/api/CorpusCallosum_data.rst b/doc/api/CorpusCallosum.data.rst
similarity index 100%
rename from doc/api/CorpusCallosum_data.rst
rename to doc/api/CorpusCallosum.data.rst
diff --git a/doc/api/CorpusCallosum_localization.rst b/doc/api/CorpusCallosum.localization.rst
similarity index 100%
rename from doc/api/CorpusCallosum_localization.rst
rename to doc/api/CorpusCallosum.localization.rst
diff --git a/doc/api/CorpusCallosum_registration.rst b/doc/api/CorpusCallosum.registration.rst
similarity index 100%
rename from doc/api/CorpusCallosum_registration.rst
rename to doc/api/CorpusCallosum.registration.rst
diff --git a/doc/api/CorpusCallosum_segmentation.rst b/doc/api/CorpusCallosum.segmentation.rst
similarity index 100%
rename from doc/api/CorpusCallosum_segmentation.rst
rename to doc/api/CorpusCallosum.segmentation.rst
diff --git a/doc/api/CorpusCallosum_shape.rst b/doc/api/CorpusCallosum.shape.rst
similarity index 100%
rename from doc/api/CorpusCallosum_shape.rst
rename to doc/api/CorpusCallosum.shape.rst
diff --git a/doc/api/CorpusCallosum_transforms.rst b/doc/api/CorpusCallosum.transforms.rst
similarity index 100%
rename from doc/api/CorpusCallosum_transforms.rst
rename to doc/api/CorpusCallosum.transforms.rst
diff --git a/doc/api/CorpusCallosum_utils.rst b/doc/api/CorpusCallosum.utils.rst
similarity index 100%
rename from doc/api/CorpusCallosum_utils.rst
rename to doc/api/CorpusCallosum.utils.rst
diff --git a/doc/api/CorpusCallosum_visualization.rst b/doc/api/CorpusCallosum.visualization.rst
similarity index 100%
rename from doc/api/CorpusCallosum_visualization.rst
rename to doc/api/CorpusCallosum.visualization.rst
diff --git a/doc/api/index.rst b/doc/api/index.rst
index 6a099907..ef022059 100644
--- a/doc/api/index.rst
+++ b/doc/api/index.rst
@@ -17,14 +17,14 @@ FastSurfer API
    CerebNet.models.rst
    CerebNet.utils.rst
    CorpusCallosum.rst
-   CorpusCallosum_data.rst
-   CorpusCallosum_localization.rst
-   CorpusCallosum_registration.rst
-   CorpusCallosum_segmentation.rst
-   CorpusCallosum_shape.rst
-   CorpusCallosum_transforms.rst
-   CorpusCallosum_utils.rst
-   CorpusCallosum_visualization.rst
+   CorpusCallosum.data.rst
+   CorpusCallosum.localization.rst
+   CorpusCallosum.registration.rst
+   CorpusCallosum.segmentation.rst
+   CorpusCallosum.shape.rst
+   CorpusCallosum.transforms.rst
+   CorpusCallosum.utils.rst
+   CorpusCallosum.visualization.rst
    HypVINN.rst
    HypVINN.dataloader.rst
    HypVINN.models.rst
diff --git a/doc/scripts/fastsurfer_cc.rst b/doc/scripts/fastsurfer_cc.rst
index 3f73ebba..d2f5fcbc 100644
--- a/doc/scripts/fastsurfer_cc.rst
+++ b/doc/scripts/fastsurfer_cc.rst
@@ -1,11 +1,21 @@
 CorpusCallosum: fastsurfer_cc.py
 ================================
 
+.. note::
+    We recommend running FastSurfer-CC through the standard ``run_fastsurfer.sh`` interface!
+
+
+..
+    [Note] To tell sphinx where in the documentation CorpusCallosum/README.md can be linked to, it needs to be included somewhere
+
+.. include:: ../../CorpusCallosum/README.md
+    :parser: fix_links.parser
+    :start-line: 1
+
 .. argparse::
    :module: CorpusCallosum.fastsurfer_cc
    :func: make_parser
    :prog: fastsurfer_cc.py
 
 .. include:: ../overview/modules/CC.md
-   :parser: myst_parser.sphinx_
-
+   :parser: fix_links.parser
From 97a7c8aca61afb474a4d19926784a3f592c9f0a8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20K=C3=BCgler?=
Date: Fri, 28 Nov 2025 17:40:45 +0100
Subject: [PATCH 36/68] Remove --qc_output_dir and related functionality to
 simplify the fastsurfer_cc.py CLI interface and internal code complexity.
--- CorpusCallosum/data/constants.py | 6 +-- CorpusCallosum/fastsurfer_cc.py | 66 ++++---------------------------- run_fastsurfer.sh | 6 ++- 3 files changed, 15 insertions(+), 63 deletions(-) diff --git a/CorpusCallosum/data/constants.py b/CorpusCallosum/data/constants.py index 77e4868d..74580931 100644 --- a/CorpusCallosum/data/constants.py +++ b/CorpusCallosum/data/constants.py @@ -47,9 +47,9 @@ "upright_lta": "mri/transforms/cc_up.lta", # lta transform from orig to upright space "orient_volume_lta": "mri/transforms/orient_volume.lta", # lta transform from orig to upright+acpc corrected space ## qc - "qc_image": "{qc_output_dir}/callosum.png", # debug image of cc contours - "thickness_image": "{qc_output_dir}/callosum.thickness.png", # whippersnappy 3D image of cc thickness - "cc_html": "{qc_output_dir}/corpus_callosum.html", # plotly cc visualization + "qc_image": None, #"callosum.png", # debug image of cc contours + "thickness_image": None, # "callosum.thickness.png", # whippersnappy 3D image of cc thickness + "cc_html": None, # "corpus_callosum.html", # plotly cc visualization ## surface "cc_surf": "surf/callosum.surf", # cc surface file "cc_thickness_overlay": "surf/callosum.thickness.w", # cc surface overlay file diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index 82a81368..fa190091 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -72,42 +72,6 @@ _TPathLike = TypeVar("_TPathLike", str, Path, Literal[None]) - -class ReplaceQCOutputDir(Path): - """ - A helper class to validate `qc_output_dir` dependent paths. - - Replaces {qc_output_dir} at the start of filename with the correct qc_output_dir. - Also returns None, if qc_output_dir was None. - """ - - def __init__(self, a: Path | str | None): - if a is None: - a = "{None}" - if "{qc_output_dir}" in str(a).removeprefix("{qc_output_dir}/"): - raise ValueError("If the argument contains {qc_output_dir}, it must start with '{qc_output_dir}/'!") - super().__init__(a) - - def replace_qc_dir(self, qc_output_dir: _TPathLike) -> Path | None: - """ - Helper function to replace {qc_output_dir} at the start of filename with the correct qc_output_dir. - - Also returns None, if qc_output_dir was None. 
-
-        Notes
-        -----
-        This function implements
-        """
-        if str(self) == "{None}":
-            return None
-        elif "{qc_output_dir}" not in str(self):
-            return self
-        elif qc_output_dir is None:
-            return None
-
-        return Path(str(self).replace("{qc_output_dir}", str(qc_output_dir)))
-
-
 class ArgumentDefaultsHelpFormatter(HelpFormatter):
     """Help message formatter which adds default values to argument help."""
 
@@ -214,14 +178,6 @@ def _slice_selection(a: str) -> SliceSelection:
         "relative to the subject_dir defined via --sd and --sid!",
     )
     add_arguments(advanced, ["threads"])
-    advanced.add_argument(
-        "--qc_output_dir",
-        type=path_or_none,
-        required=False,
-        default=None,
-        help="Enables quality control output (paths starting with {qc_output_dir} by default) and sets {qc_output_dir} "
-        "(the FastSurfer standard is 'qc_snapshots' to save these files in subject_dir/qc_snapshots).",
-    )
     advanced.add_argument(
         "--upright_volume",
         type=path_or_none,
@@ -268,8 +224,8 @@ def _slice_selection(a: str) -> SliceSelection:
     )
     advanced.add_argument(
         "--qc_image",
-        type=ReplaceQCOutputDir,
-        help="Path for QC visualization image (if it starts with {qc_output_dir}, that is replace by --qc_output_dir).",
+        type=path_or_none,
+        help="Path for QC visualization image.",
         default=DEFAULT_OUTPUT_PATHS["qc_image"],
     )
     advanced.add_argument(
@@ -281,8 +237,8 @@ def _slice_selection(a: str) -> SliceSelection:
     )
     advanced.add_argument(
         "--thickness_image",
-        type=ReplaceQCOutputDir,
-        help="Path for thickness image (if it starts with {qc_output_dir}, that is replace by --qc_output_dir).",
+        type=path_or_none,
+        help="Path for thickness image.",
         default=DEFAULT_OUTPUT_PATHS["thickness_image"],
     )
     advanced.add_argument(
@@ -301,9 +257,8 @@ def _slice_selection(a: str) -> SliceSelection:
     advanced.add_argument(
         "--cc_interactive_html", "--cc_html", dest="cc_html",
-        type=ReplaceQCOutputDir,
-        help="Path to the corpus callosum interactive 3D visualization HTML file (if it starts with {qc_output_dir}, "
-        "that is replace by --qc_output_dir).",
+        type=path_or_none,
+        help="Path to the corpus callosum interactive 3D visualization HTML file.",
         default=DEFAULT_OUTPUT_PATHS["cc_html"],
     )
     advanced.add_argument(
@@ -384,12 +339,9 @@ def options_parse() -> argparse.Namespace:
 
     # Create parent directories for all output paths
     for path_name in all_paths:
-        path: ReplaceQCOutputDir | Path | None = getattr(args, path_name, None)
-        if isinstance(path, ReplaceQCOutputDir):
-            path = path.replace_qc_dir(getattr(args, "qc_output_dir", None))
+        path: Path | None = getattr(args, path_name, None)
         if isinstance(path, Path) and not args.subject_dir and not path.is_absolute():
             parser.error(f"Must specify --sd and --sid if any path is relative but {path} for {path_name} is relative.")
-        setattr(args, path_name, path)
 
     return args
 
@@ -565,7 +517,6 @@ def main(
     aseg_name: str | Path,
     subject_dir: str | Path,
     slice_selection: SliceSelection = "middle",
-    qc_output_dir: str | Path | None = None,
     num_thickness_points: int = 100,
     subdivisions: list[float] | None = None,
     subdivision_method: SubdivisionMethod = "shape",
@@ -604,8 +555,6 @@ def main(
         FastSurfer/FreeSurfer subject directory and directory for output files.
     slice_selection : "middle", "all" or int, default="middle"
         Which slices to process.
-    qc_output_dir : str or Path, optional
-        Directory for quality control outputs, activates qc_image, thickness_image, cc_html.
     num_thickness_points : int, default=100
         Number of points for thickness estimation.
subdivisions : list[float], optional @@ -978,7 +927,6 @@ def main( aseg_name=options.aseg_name, subject_dir=options.subject_dir, slice_selection=options.slice_selection, - qc_output_dir=options.qc_output_dir, num_thickness_points=options.num_thickness_points, subdivisions=list(options.subdivisions), # default value is type _fmt_list (does not pickle) subdivision_method=str(options.subdivision_method), # default value is type do not print (does not pickle) diff --git a/run_fastsurfer.sh b/run_fastsurfer.sh index 12df8d07..49807fb3 100755 --- a/run_fastsurfer.sh +++ b/run_fastsurfer.sh @@ -492,7 +492,11 @@ case $key in ;; # several options that set a variable - --qc_snap) hypvinn_flags+=(--qc_snap) ; cc_flags+=("--qc_output_dir" "qc_snapshots") ;; + --qc_snap) + hypvinn_flags+=(--qc_snap) ; + cc_flags+=(--qc_image "qc_snapshots/callosum.png" --thickness_image "qc_snapshots/callosum.thickness.png" + --cc_html "qc_snapshots/corpus_callosum.html") + ;; ############################################################## # surf-pipeline options From 576508dfbbdb33e8254e5f1c58da30000789c596 Mon Sep 17 00:00:00 2001 From: ClePol Date: Fri, 28 Nov 2025 17:42:46 +0100 Subject: [PATCH 37/68] file renaming, removed unused code, documentation update --- CorpusCallosum/cc_visualization.py | 2 +- CorpusCallosum/data/fsaverage_cc_template.py | 2 +- CorpusCallosum/fastsurfer_cc.py | 28 +++++++----- ...localization_inference.py => inference.py} | 2 +- CorpusCallosum/registration/__init__.py | 0 ...segmentation_inference.py => inference.py} | 2 +- ...int_heuristic.py => endpoint_heuristic.py} | 0 CorpusCallosum/shape/{cc_mesh.py => mesh.py} | 27 +---------- .../shape/{cc_metrics.py => metrics.py} | 0 ...cc_postprocessing.py => postprocessing.py} | 20 +++------ ...gment_contour.py => subsegment_contour.py} | 0 .../shape/{cc_thickness.py => thickness.py} | 0 ...lization_transforms.py => localization.py} | 0 ...entation_transforms.py => segmentation.py} | 0 .../mapping_helpers.py | 0 .../{visualization => utils}/visualization.py | 45 ------------------- CorpusCallosum/visualization/__init__.py | 0 doc/api/CorpusCallosum.localization.rst | 2 +- doc/api/CorpusCallosum.registration.rst | 9 ---- doc/api/CorpusCallosum.segmentation.rst | 2 +- doc/api/CorpusCallosum.shape.rst | 12 ++--- doc/api/CorpusCallosum.transforms.rst | 4 +- doc/api/CorpusCallosum.utils.rst | 2 + doc/api/CorpusCallosum.visualization.rst | 9 ---- doc/api/index.rst | 2 - pyproject.toml | 2 +- run_fastsurfer.sh | 2 + 27 files changed, 43 insertions(+), 131 deletions(-) rename CorpusCallosum/localization/{localization_inference.py => inference.py} (98%) delete mode 100644 CorpusCallosum/registration/__init__.py rename CorpusCallosum/segmentation/{segmentation_inference.py => inference.py} (99%) rename CorpusCallosum/shape/{cc_endpoint_heuristic.py => endpoint_heuristic.py} (100%) rename CorpusCallosum/shape/{cc_mesh.py => mesh.py} (98%) rename CorpusCallosum/shape/{cc_metrics.py => metrics.py} (100%) rename CorpusCallosum/shape/{cc_postprocessing.py => postprocessing.py} (97%) rename CorpusCallosum/shape/{cc_subsegment_contour.py => subsegment_contour.py} (100%) rename CorpusCallosum/shape/{cc_thickness.py => thickness.py} (100%) rename CorpusCallosum/transforms/{localization_transforms.py => localization.py} (100%) rename CorpusCallosum/transforms/{segmentation_transforms.py => segmentation.py} (100%) rename CorpusCallosum/{registration => utils}/mapping_helpers.py (100%) rename CorpusCallosum/{visualization => utils}/visualization.py (86%) delete mode 
100644 CorpusCallosum/visualization/__init__.py delete mode 100644 doc/api/CorpusCallosum.registration.rst delete mode 100644 doc/api/CorpusCallosum.visualization.rst diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index d23bf07a..4a18f229 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -8,7 +8,7 @@ from CorpusCallosum.data.constants import FSAVERAGE_DATA_PATH from CorpusCallosum.data.fsaverage_cc_template import load_fsaverage_cc_template from CorpusCallosum.data.read_write import load_fsaverage_data -from CorpusCallosum.shape.cc_mesh import CCMesh +from CorpusCallosum.shape.mesh import CCMesh def make_parser() -> argparse.ArgumentParser: diff --git a/CorpusCallosum/data/fsaverage_cc_template.py b/CorpusCallosum/data/fsaverage_cc_template.py index b331640b..a49e5c61 100644 --- a/CorpusCallosum/data/fsaverage_cc_template.py +++ b/CorpusCallosum/data/fsaverage_cc_template.py @@ -20,7 +20,7 @@ from scipy import ndimage from CorpusCallosum.data import constants -from CorpusCallosum.shape.cc_postprocessing import recon_cc_surf_measure +from CorpusCallosum.shape.postprocessing import recon_cc_surf_measure from FastSurferCNN.utils.brainvolstats import mask_in_array diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index fa190091..90b0ebab 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -42,22 +42,23 @@ load_fsaverage_centroids, load_fsaverage_data, ) -from CorpusCallosum.localization import localization_inference -from CorpusCallosum.registration.mapping_helpers import ( - apply_transform_to_pt, - apply_transform_to_volume, - calc_mapping_to_standard_space, - interpolate_midplane, - map_softlabels_to_orig, -) -from CorpusCallosum.segmentation import segmentation_inference, segmentation_postprocessing -from CorpusCallosum.shape.cc_postprocessing import ( +from CorpusCallosum.localization import inference as localization_inference +from CorpusCallosum.segmentation import inference as segmentation_inference +from CorpusCallosum.segmentation import segmentation_postprocessing +from CorpusCallosum.shape.postprocessing import ( SliceSelection, SubdivisionMethod, check_area_changes, make_subdivision_mask, recon_cc_surf_measures_multi, ) +from CorpusCallosum.utils.mapping_helpers import ( + apply_transform_to_pt, + apply_transform_to_volume, + calc_mapping_to_standard_space, + interpolate_midplane, + map_softlabels_to_orig, +) from FastSurferCNN.data_loader.conform import is_conform from FastSurferCNN.segstats import HelpFormatter from FastSurferCNN.utils import logging @@ -669,10 +670,13 @@ def main( orig = cast(nib.analyze.SpatialImage, nib.load(sd.conf_name)) # 5 mm around the midplane (making sure to get rl by as_closest_canonical) - slices_to_analyze = int(np.ceil(5 / nib.as_closest_canonical(orig).header.get_zooms()[0])) // 2 * 2 + 1 + vox_size = nib.as_closest_canonical(orig).header.get_zooms()[0] + slices_to_analyze = int(np.ceil(5 / vox_size)) + if slices_to_analyze % 2 == 0: + slices_to_analyze += 1 logger.info( - f"Segmenting {slices_to_analyze} slices (5 mm width at {orig.header.get_zooms()[0]} mm resolution, " + f"Segmenting {slices_to_analyze} slices (5 mm width at {vox_size} mm resolution, " "center around the mid-sagittal plane)" ) diff --git a/CorpusCallosum/localization/localization_inference.py b/CorpusCallosum/localization/inference.py similarity index 98% rename from CorpusCallosum/localization/localization_inference.py rename to 
CorpusCallosum/localization/inference.py index 20c6e0f1..0d6e5716 100644 --- a/CorpusCallosum/localization/localization_inference.py +++ b/CorpusCallosum/localization/inference.py @@ -20,7 +20,7 @@ from monai.networks.nets import DenseNet from numpy import typing as npt -from CorpusCallosum.transforms.localization_transforms import CropAroundACPCFixedSize +from CorpusCallosum.transforms.localization import CropAroundACPCFixedSize from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults from FastSurferCNN.download_checkpoints import main as download_checkpoints diff --git a/CorpusCallosum/registration/__init__.py b/CorpusCallosum/registration/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/CorpusCallosum/segmentation/segmentation_inference.py b/CorpusCallosum/segmentation/inference.py similarity index 99% rename from CorpusCallosum/segmentation/segmentation_inference.py rename to CorpusCallosum/segmentation/inference.py index 692f0039..70c63d86 100644 --- a/CorpusCallosum/segmentation/segmentation_inference.py +++ b/CorpusCallosum/segmentation/inference.py @@ -21,7 +21,7 @@ from numpy import typing as npt from CorpusCallosum.data import constants -from CorpusCallosum.transforms.segmentation_transforms import CropAroundACPC +from CorpusCallosum.transforms.segmentation import CropAroundACPC from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults from FastSurferCNN.download_checkpoints import main as download_checkpoints diff --git a/CorpusCallosum/shape/cc_endpoint_heuristic.py b/CorpusCallosum/shape/endpoint_heuristic.py similarity index 100% rename from CorpusCallosum/shape/cc_endpoint_heuristic.py rename to CorpusCallosum/shape/endpoint_heuristic.py diff --git a/CorpusCallosum/shape/cc_mesh.py b/CorpusCallosum/shape/mesh.py similarity index 98% rename from CorpusCallosum/shape/cc_mesh.py rename to CorpusCallosum/shape/mesh.py index 3be0f82a..49761130 100644 --- a/CorpusCallosum/shape/cc_mesh.py +++ b/CorpusCallosum/shape/mesh.py @@ -26,8 +26,8 @@ from scipy.ndimage import gaussian_filter1d import FastSurferCNN.utils.logging as logging -from CorpusCallosum.shape.cc_endpoint_heuristic import smooth_contour -from CorpusCallosum.shape.cc_thickness import make_mesh_from_contour +from CorpusCallosum.shape.endpoint_heuristic import smooth_contour +from CorpusCallosum.shape.thickness import make_mesh_from_contour from FastSurferCNN.utils.common import suppress_stdout try: @@ -996,18 +996,6 @@ def plot_cc_contour_with_levelsets( plot_values = np.array(self.thickness_values[contour_idx][~np.isnan(self.thickness_values[contour_idx])][:100])[ ::-1 ] - # double plot values with linear interpolation - - # Create bar plot of thickness values - # fig, ax = plt.subplots(figsize=(10, 4)) - # ax.bar(range(len(plot_values)), plot_values) - # ax.set_xlabel('Point Index') - # ax.set_ylabel('Thickness (mm)') - # ax.set_title('Thickness Distribution') - # ax.set_ylim(0, 0.06) - # ax.invert_xaxis() - # plt.tight_layout() - # plt.show() points, trias = make_mesh_from_contour(self.contours[contour_idx], max_volume=0.5, min_angle=25, verbose=False) @@ -1146,22 +1134,11 @@ def plot_cc_contour_with_levelsets( # Plot the outside contour on top for clear boundary plt.plot(outside_contour[0], outside_contour[1], "k-", linewidth=2, label="CC Contour", transform=transform) - # plot levelpaths - # for i, path in 
enumerate(levelpaths): - # plt.plot(path[:,0], path[:,1], 'k--', linewidth=1, alpha=0.2, transform=transform) - # plot midline - # if midline_equidistant is not None: - # midline_x, midline_y = zip(*midline_equidistant) - # plt.plot(midline_x, midline_y, 'k--', linewidth=2, transform=transform, alpha=0.2) - plt.axis("equal") plt.title(title, fontsize=14, fontweight="bold") # plt.legend(loc='best') plt.gca().invert_xaxis() plt.axis("off") - # plt.tight_layout() - # plt.ylim(-105, -75) - # plt.xlim(181, 101) if save_path is not None: self.__make_parent_folder(save_path) plt.savefig(save_path, dpi=300) diff --git a/CorpusCallosum/shape/cc_metrics.py b/CorpusCallosum/shape/metrics.py similarity index 100% rename from CorpusCallosum/shape/cc_metrics.py rename to CorpusCallosum/shape/metrics.py diff --git a/CorpusCallosum/shape/cc_postprocessing.py b/CorpusCallosum/shape/postprocessing.py similarity index 97% rename from CorpusCallosum/shape/cc_postprocessing.py rename to CorpusCallosum/shape/postprocessing.py index 20bc4988..3b292c52 100644 --- a/CorpusCallosum/shape/cc_postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -20,18 +20,18 @@ import FastSurferCNN.utils.logging as logging from CorpusCallosum.data.constants import CC_LABEL, FSAVERAGE_MIDDLE, SUBSEGMENT_LABELS -from CorpusCallosum.shape.cc_endpoint_heuristic import get_endpoints -from CorpusCallosum.shape.cc_mesh import CCMesh -from CorpusCallosum.shape.cc_metrics import calculate_cc_index -from CorpusCallosum.shape.cc_subsegment_contour import ( +from CorpusCallosum.shape.endpoint_heuristic import get_endpoints +from CorpusCallosum.shape.mesh import CCMesh +from CorpusCallosum.shape.metrics import calculate_cc_index +from CorpusCallosum.shape.subsegment_contour import ( get_primary_eigenvector, hampel_subdivide_contour, subdivide_contour, subsegment_midline_orthogonal, transform_to_acpc_standard, ) -from CorpusCallosum.shape.cc_thickness import cc_thickness, convert_to_ras -from CorpusCallosum.visualization.visualization import plot_contours +from CorpusCallosum.shape.thickness import cc_thickness, convert_to_ras +from CorpusCallosum.utils.visualization import plot_contours from FastSurferCNN.utils.common import SubjectDirectory, suppress_stdout, update_docstring from FastSurferCNN.utils.parallel import process_executor, thread_executor @@ -580,14 +580,6 @@ def make_subdivision_mask( # All points to the right of this line belong to the next segment or beyond subdivision_mask[points_right_of_line] = subsegment_labels_anterior_posterior[segment_idx + 1] - - # Debug visualization (optional) - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots(figsize=(10, 8)) - # ax.imshow(subdivision_mask, cmap="tab10") - # ax.plot([line_start[0], line_end[0]], [line_start[1], line_end[1]], "r-", linewidth=2) - # ax.set_title(f"After subdivision line {segment_idx}") - # plt.show() return subdivision_mask diff --git a/CorpusCallosum/shape/cc_subsegment_contour.py b/CorpusCallosum/shape/subsegment_contour.py similarity index 100% rename from CorpusCallosum/shape/cc_subsegment_contour.py rename to CorpusCallosum/shape/subsegment_contour.py diff --git a/CorpusCallosum/shape/cc_thickness.py b/CorpusCallosum/shape/thickness.py similarity index 100% rename from CorpusCallosum/shape/cc_thickness.py rename to CorpusCallosum/shape/thickness.py diff --git a/CorpusCallosum/transforms/localization_transforms.py b/CorpusCallosum/transforms/localization.py similarity index 100% rename from CorpusCallosum/transforms/localization_transforms.py 
rename to CorpusCallosum/transforms/localization.py diff --git a/CorpusCallosum/transforms/segmentation_transforms.py b/CorpusCallosum/transforms/segmentation.py similarity index 100% rename from CorpusCallosum/transforms/segmentation_transforms.py rename to CorpusCallosum/transforms/segmentation.py diff --git a/CorpusCallosum/registration/mapping_helpers.py b/CorpusCallosum/utils/mapping_helpers.py similarity index 100% rename from CorpusCallosum/registration/mapping_helpers.py rename to CorpusCallosum/utils/mapping_helpers.py diff --git a/CorpusCallosum/visualization/visualization.py b/CorpusCallosum/utils/visualization.py similarity index 86% rename from CorpusCallosum/visualization/visualization.py rename to CorpusCallosum/utils/visualization.py index db2feea2..285e0d84 100644 --- a/CorpusCallosum/visualization/visualization.py +++ b/CorpusCallosum/utils/visualization.py @@ -213,48 +213,3 @@ def plot_contours( Path(output_path).parent.mkdir(parents=True, exist_ok=True) plt.savefig(output_path, dpi=300, bbox_inches="tight") - - -def plot_midplane(grid_orig: np.ndarray, orig: np.ndarray) -> None: - """Create a 3D visualization of grid points in original image space. - - Parameters - ---------- - grid_orig : np.ndarray - Grid points in original space, shape (3, N). - orig : np.ndarray - Original image for dimension reference. - - Notes - ----- - The function: - 1. Creates a 3D scatter plot of grid points - 2. Samples every 40th point to avoid overcrowding - 3. Sets axis limits based on original image dimensions - 4. Shows the plot interactively - """ - # Create a figure showing grid points in original space - - # Create 3D plot - fig = plt.figure(figsize=(10, 10)) - ax = fig.add_subplot(111, projection="3d") - - # Plot every 10th point to avoid overcrowding - sample_idx = np.arange(0, grid_orig.shape[1], 40) - ax.scatter(*grid_orig[:3, sample_idx], c="r", alpha=0.1, marker=".") - - # Set labels - ax.set_xlabel("X") - ax.set_ylabel("Y") - ax.set_zlabel("Z") - ax.set_title("Grid Points in Original Image Space") - - # Set axis limits to image dimensions - ax.set_xlim(0, orig.shape[0]) - ax.set_ylim(0, orig.shape[1]) - ax.set_zlim(0, orig.shape[2]) - - # Save plot - plt.show() - # plt.savefig('grid_points.png') - # plt.close() diff --git a/CorpusCallosum/visualization/__init__.py b/CorpusCallosum/visualization/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/doc/api/CorpusCallosum.localization.rst b/doc/api/CorpusCallosum.localization.rst index 142ad093..9c6c3b40 100644 --- a/doc/api/CorpusCallosum.localization.rst +++ b/doc/api/CorpusCallosum.localization.rst @@ -6,4 +6,4 @@ CorpusCallosum.localization .. autosummary:: :toctree: generated/ - localization_inference + inference diff --git a/doc/api/CorpusCallosum.registration.rst b/doc/api/CorpusCallosum.registration.rst deleted file mode 100644 index add8daa1..00000000 --- a/doc/api/CorpusCallosum.registration.rst +++ /dev/null @@ -1,9 +0,0 @@ -CorpusCallosum.registration -============================ - -.. currentmodule:: CorpusCallosum.registration - -.. autosummary:: - :toctree: generated/ - - mapping_helpers diff --git a/doc/api/CorpusCallosum.segmentation.rst b/doc/api/CorpusCallosum.segmentation.rst index 291c14e3..0269688b 100644 --- a/doc/api/CorpusCallosum.segmentation.rst +++ b/doc/api/CorpusCallosum.segmentation.rst @@ -6,5 +6,5 @@ CorpusCallosum.segmentation .. 
autosummary:: :toctree: generated/ - segmentation_inference + inference segmentation_postprocessing diff --git a/doc/api/CorpusCallosum.shape.rst b/doc/api/CorpusCallosum.shape.rst index 5fcccd7b..cd89aedc 100644 --- a/doc/api/CorpusCallosum.shape.rst +++ b/doc/api/CorpusCallosum.shape.rst @@ -6,9 +6,9 @@ CorpusCallosum.shape .. autosummary:: :toctree: generated/ - cc_postprocessing - cc_mesh - cc_metrics - cc_thickness - cc_subsegment_contour - cc_endpoint_heuristic + postprocessing + mesh + metrics + thickness + subsegment_contour + endpoint_heuristic diff --git a/doc/api/CorpusCallosum.transforms.rst b/doc/api/CorpusCallosum.transforms.rst index 44b63315..14756a92 100644 --- a/doc/api/CorpusCallosum.transforms.rst +++ b/doc/api/CorpusCallosum.transforms.rst @@ -6,5 +6,5 @@ CorpusCallosum.transforms .. autosummary:: :toctree: generated/ - localization_transforms - segmentation_transforms + localization + segmentation diff --git a/doc/api/CorpusCallosum.utils.rst b/doc/api/CorpusCallosum.utils.rst index f0490a10..a6595d5b 100644 --- a/doc/api/CorpusCallosum.utils.rst +++ b/doc/api/CorpusCallosum.utils.rst @@ -7,3 +7,5 @@ CorpusCallosum.utils :toctree: generated/ checkpoint + mapping_helpers + visualization diff --git a/doc/api/CorpusCallosum.visualization.rst b/doc/api/CorpusCallosum.visualization.rst deleted file mode 100644 index f6801dec..00000000 --- a/doc/api/CorpusCallosum.visualization.rst +++ /dev/null @@ -1,9 +0,0 @@ -CorpusCallosum.visualization -============================== - -.. currentmodule:: CorpusCallosum.visualization - -.. autosummary:: - :toctree: generated/ - - visualization diff --git a/doc/api/index.rst b/doc/api/index.rst index ef022059..fd606a8b 100644 --- a/doc/api/index.rst +++ b/doc/api/index.rst @@ -19,12 +19,10 @@ FastSurfer API CorpusCallosum.rst CorpusCallosum.data.rst CorpusCallosum.localization.rst - CorpusCallosum.registration.rst CorpusCallosum.segmentation.rst CorpusCallosum.shape.rst CorpusCallosum.transforms.rst CorpusCallosum.utils.rst - CorpusCallosum.visualization.rst HypVINN.rst HypVINN.dataloader.rst HypVINN.models.rst diff --git a/pyproject.toml b/pyproject.toml index 7d05072e..6c116e0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,13 +53,13 @@ dependencies = [ 'monai>=1.4.0', 'meshpy>=2025.1.1', 'pyrr>=0.10.3', - 'whippersnappy>=1.3.1', 'pip>=25.0', ] [project.optional-dependencies] doc = [ 'furo!=2023.8.17', + 'whippersnappy>=1.3.1', 'memory-profiler', 'myst-parser', 'numpydoc', diff --git a/run_fastsurfer.sh b/run_fastsurfer.sh index 49807fb3..6727bddd 100755 --- a/run_fastsurfer.sh +++ b/run_fastsurfer.sh @@ -469,6 +469,8 @@ case $key in # corupus callosum module options #============================================================= --no_cc) run_cc_module="0" ;; + # TODO: remove this dev flag + --upright) cc_flags+=("--upright_volume" "mri/upright.mgz") ;; # cereb module options #============================================================= From 22768a131cbb60085b0dff097fd403711a0c497d Mon Sep 17 00:00:00 2001 From: ClePol Date: Thu, 4 Dec 2025 15:00:13 +0100 Subject: [PATCH 38/68] fixed commandline texts, parameter and absolute paths --- CorpusCallosum/fastsurfer_cc.py | 25 +++++++----- .../segmentation_postprocessing.py | 10 +++-- CorpusCallosum/shape/mesh.py | 40 +++++++++++++------ 3 files changed, 49 insertions(+), 26 deletions(-) diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index 90b0ebab..5262fabd 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -138,9 
+138,9 @@ def _set_help_sid(action): "--subdivisions", type=float, metavar="FRAC", - nargs=4, default=_FixFloatFormattingList([1 / 6, 1 / 2, 2 / 3, 3 / 4], ".3f"), - help="List of FOUR subdivision fractions for the corpus callosum subsegmentation.", + help="List of subdivision fractions for the corpus callosum subsegmentation" + "(allows for an arbitrary number of fractions).", ) parser.add_argument( "--subdivision_method", @@ -333,16 +333,22 @@ def options_parse() -> argparse.Namespace: if not args.aseg_name: args.aseg_name = args.subject_dir / DEFAULT_INPUT_PATHS["aseg_name"] + else: + print("WARNING: Not providing subject_dir leads to discarding of files with relative paths!") + args.subject_dir = Path("/dev/null/no-subject-dir-set") all_paths = ("segmentation", "segmentation_in_orig", "cc_measures", "upright_lta", "orient_volume_lta", "cc_surf", - "softlabels_cc", "softlabels_fn", "softlabels_background", "cc_mid_measures", "cc_thickness_overlay", + "softlabels_cc", "softlabels_fn", "softlabels_background", "cc_mid_measures", "thickness_overlay", "qc_image", "thickness_image", "cc_html") # Create parent directories for all output paths for path_name in all_paths: path: Path | None = getattr(args, path_name, None) if isinstance(path, Path) and not args.subject_dir and not path.is_absolute(): - parser.error(f"Must specify --sd and --sid if any path is relative but {path} for {path_name} is relative.") + # set path to none in arguments + # FIXME: Should there be a check, if a specific "path_name" is mandatory? + print(f"WARNING: Not writing {path_name}, because --sd and --sid are not specified and {path} is relative.") + setattr(args, path_name, None) return args @@ -793,11 +799,12 @@ def main( fsaverage_middle=FSAVERAGE_MIDDLE, cc_subseg_midslice=cc_subseg_midslice, )) - io_futures.append(thread_executor().submit( - nib.save, - nib.MGHImage(cc_fn_seg_labels, seg_affine, orig.header), - sd.filename_by_attribute("cc_segmentation"), - )) + if sd.has_attribute("cc_segmentation"): + io_futures.append(thread_executor().submit( + nib.save, + nib.MGHImage(cc_fn_seg_labels, seg_affine, orig.header), + sd.filename_by_attribute("cc_segmentation"), + )) METRICS = [ "areas", diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py index 51cbffe3..265ae634 100644 --- a/CorpusCallosum/segmentation/segmentation_postprocessing.py +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -335,11 +335,13 @@ def get_cc_volume_voxel( desired_width_vox = desired_width_mm / voxel_size[0] fraction_of_voxel_at_edge = (desired_width_vox % 1) / 2 - if fraction_of_voxel_at_edge > 0: - desired_width_vox = int(np.floor(desired_width_vox) + 1) - desired_width_vox = desired_width_vox + 1 if desired_width_vox % 2 == 0 else desired_width_vox + if fraction_of_voxel_at_edge > 0: + # make sure the assumentation is correct that the CC mask has an odd number of voxels + # and the leftmost and rightmost voxels are the edges at the desired width + cc_width_vox = int(np.floor(desired_width_vox) + 1) + cc_width_vox = cc_width_vox + 1 if cc_width_vox % 2 == 0 else cc_width_vox - assert cc_mask.shape[0] == desired_width_vox, (f"CC mask should have {desired_width_vox} voxels, " + assert cc_mask.shape[0] == cc_width_vox, (f"CC mask should have {cc_width_vox} voxels, " f"but has {cc_mask.shape[0]}") left_partial_volume = np.sum(cc_mask[0]) * voxel_volume * fraction_of_voxel_at_edge diff --git a/CorpusCallosum/shape/mesh.py 
b/CorpusCallosum/shape/mesh.py index 49761130..4e5f2e9c 100644 --- a/CorpusCallosum/shape/mesh.py +++ b/CorpusCallosum/shape/mesh.py @@ -968,6 +968,7 @@ def plot_cc_contour_with_levelsets( title: str | None = None, save_path: str | None = None, colorbar: bool = True, + mode: str = "p-value", ) -> matplotlib.figure.Figure: """Plot a contour with levelset visualization. @@ -986,14 +987,15 @@ def plot_cc_contour_with_levelsets( Path to save the plot. If None, displays interactively, by default None. colorbar : bool, optional Whether to show the colorbar, by default True. - + mode: str, optional + Mode of the plot, by default "p-value". Can be "p-value" or "icc". Returns ------- matplotlib.figure.Figure The created figure object. """ - plot_values = np.array(self.thickness_values[contour_idx][~np.isnan(self.thickness_values[contour_idx])][:100])[ + plot_values = np.array(self.thickness_values[contour_idx][~np.isnan(self.thickness_values[contour_idx])])[ ::-1 ] @@ -1002,7 +1004,7 @@ def plot_cc_contour_with_levelsets( # make points 3D by adding zero points = np.column_stack([points, np.zeros(len(points))]) - levelpaths, _ = self._create_levelpaths(contour_idx, points, trias, num_points=99) + levelpaths, _ = self._create_levelpaths(contour_idx, points, trias, num_points=len(plot_values)-2) outside_contour = self.contours[contour_idx].T @@ -1079,9 +1081,16 @@ def plot_cc_contour_with_levelsets( # Apply the mask to only show values inside the contour masked_values = np.where(mask, grid_values, np.nan) - # Sample colormaps (e.g., 'binary' and 'gist_heat_r') - colors1 = plt.cm.binary([0.4] * 128) - colors2 = plt.cm.hot(np.linspace(0.8, 0.1, 128)) + + if mode == "p-value": + # Sample colormaps + colors1 = plt.cm.binary([0.4] * 128) + colors2 = plt.cm.hot(np.linspace(0.8, 0.1, 128)) + + + else: + colors1 = plt.cm.Blues(np.linspace(0, 1, 128)) + colors2 = plt.cm.binary([0.4] * 128) # Combine the color samples colors = np.vstack((colors2, colors1)) @@ -1105,7 +1114,7 @@ def plot_cc_contour_with_levelsets( alpha=1, interpolation="bilinear", vmin=0, - vmax=0.10, + vmax=0.10 if mode == "p-value" else 1, transform=transform, ) @@ -1117,19 +1126,24 @@ def plot_cc_contour_with_levelsets( alpha=1, interpolation="bilinear", vmin=0, - vmax=0.10, + vmax=0.10 if mode == "p-value" else 1, # norm=LogNorm(vmin=1e-3, vmax=0.1), # Set minimum to avoid log(0) transform=transform, ) + + if colorbar: # Add a colorbar cbar = plt.colorbar(aspect=10) - cbar.ax.set_ylim(0.001, 0.054) - cbar.ax.set_yticks([0.0, 0.01, 0.02, 0.03, 0.04, 0.05]) - # cbar.ax.set_yticks([0.001, 0.01, 0.05]) - # cbar.ax.set_yticklabels(['0.001', '0.01', '0.05']) - cbar.set_label("p-value (log scale)") + if mode == "p-value": + cbar.ax.set_ylim(0.001, 0.054) + cbar.ax.set_yticks([0.0, 0.01, 0.02, 0.03, 0.04, 0.05]) + cbar.set_label("p-value (log scale)") + elif mode == "icc": + cbar.ax.set_ylim(0, 1) + cbar.ax.set_yticks([0, 0.25, 0.5, 0.75, 1]) + cbar.ax.set_label("Intraclass correlation coefficient") # Plot the outside contour on top for clear boundary plt.plot(outside_contour[0], outside_contour[1], "k-", linewidth=2, label="CC Contour", transform=transform) From 4ec8ec550e2a4a45cd35ce8df413363cf7a7706c Mon Sep 17 00:00:00 2001 From: ClePol Date: Thu, 4 Dec 2025 17:25:33 +0100 Subject: [PATCH 39/68] Rewrite of CCIndex, fixed middle slice selection argument, helptext --- CorpusCallosum/fastsurfer_cc.py | 9 +- CorpusCallosum/shape/metrics.py | 348 ++++++++++++++++++++----- CorpusCallosum/shape/postprocessing.py | 6 +- 3 files changed, 296 
insertions(+), 67 deletions(-)

diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py
index 5262fabd..0bdf17c5 100644
--- a/CorpusCallosum/fastsurfer_cc.py
+++ b/CorpusCallosum/fastsurfer_cc.py
@@ -139,8 +139,9 @@ def _set_help_sid(action):
         type=float,
         metavar="FRAC",
         default=_FixFloatFormattingList([1 / 6, 1 / 2, 2 / 3, 3 / 4], ".3f"),
-        help="List of subdivision fractions for the corpus callosum subsegmentation"
-        "(allows for an arbitrary number of fractions).",
+        help="List of subdivision fractions for the corpus callosum subsegmentation. "
+        "The method allows for an arbitrary number of fractions. "
+        "By default, it follows the Hofer-Frahm convention."
     )
     parser.add_argument(
         "--subdivision_method",
@@ -939,8 +940,8 @@ def main(
         subject_dir=options.subject_dir,
         slice_selection=options.slice_selection,
         num_thickness_points=options.num_thickness_points,
-        subdivisions=list(options.subdivisions),  # default value is type _fmt_list (does not pickle)
-        subdivision_method=str(options.subdivision_method),  # default value is type do not print (does not pickle)
+        subdivisions=list(options.subdivisions),
+        subdivision_method=str(options.subdivision_method),
         contour_smoothing=options.contour_smoothing,
         save_template_dir=options.save_template_dir,
         device=options.device,
diff --git a/CorpusCallosum/shape/metrics.py b/CorpusCallosum/shape/metrics.py
index 7ae55f52..921c819f 100644
--- a/CorpusCallosum/shape/metrics.py
+++ b/CorpusCallosum/shape/metrics.py
@@ -18,85 +18,311 @@
 logger = logging.get_logger(__name__)
 
 
-def calculate_cc_index(cc_contour: np.ndarray) -> float:
-    """Calculate CC index based on three perpendicular measurements.
+
+# TODO: we could make this more robust by standardizing orientation with AC/PC and smoothing the contour
+
+def _line_segment_intersection(
+    line_point: np.ndarray,
+    line_dir: np.ndarray,
+    seg_start: np.ndarray,
+    seg_end: np.ndarray,
+    tol: float = 1e-10,
+) -> np.ndarray | None:
+    """Compute intersection between an infinite line and a line segment.
+
+    Uses the parametric form:
+    - Line: P = line_point + t * line_dir
+    - Segment: Q = seg_start + s * (seg_end - seg_start), where s ∈ [0, 1]
+
+    Parameters
+    ----------
+    line_point : np.ndarray
+        A point on the infinite line, shape (2,).
+    line_dir : np.ndarray
+        Direction vector of the line, shape (2,).
+    seg_start : np.ndarray
+        Start point of the segment, shape (2,).
+    seg_end : np.ndarray
+        End point of the segment, shape (2,).
+    tol : float
+        Tolerance for numerical comparisons.
+
+    Returns
+    -------
+    np.ndarray | None
+        Intersection point as shape (2,) array, or None if no intersection.
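+
+    Examples
+    --------
+    A minimal, hypothetical sanity check (not part of the original patch; added
+    for illustration only). A horizontal line through the origin crosses the
+    vertical segment at x = 1:
+
+    >>> import numpy as np
+    >>> _line_segment_intersection(np.array([0.0, 0.0]), np.array([1.0, 0.0]),
+    ...                            np.array([1.0, -1.0]), np.array([1.0, 1.0]))
+    array([1., 0.])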
+    """
+    seg_dir = seg_end - seg_start
+
+    # Build the linear system: [line_dir, -seg_dir] @ [t, s].T = seg_start - line_point
+    # Matrix A = [[line_dir[0], -seg_dir[0]], [line_dir[1], -seg_dir[1]]]
+    A = np.array([[line_dir[0], -seg_dir[0]],
+                  [line_dir[1], -seg_dir[1]]])
+    b = seg_start - line_point
+
+    # Check if lines are parallel (determinant ≈ 0)
+    det = A[0, 0] * A[1, 1] - A[0, 1] * A[1, 0]
+    if abs(det) < tol:
+        return None
+
+    # Solve for t and s using Cramer's rule (faster than linalg.solve for 2x2)
+    t = (b[0] * A[1, 1] - b[1] * A[0, 1]) / det
+    s = (A[0, 0] * b[1] - A[1, 0] * b[0]) / det
+
+    # Check if intersection is within the segment [0, 1]
+    if -tol <= s <= 1.0 + tol:
+        return line_point + t * line_dir
+    return None
+
+
+def get_intersections(
+    contour: np.ndarray, start_point: np.ndarray, direction: np.ndarray
+) -> np.ndarray:
+    """Find intersection points between an infinite line and a closed contour.
+
+    Parameters
+    ----------
+    contour : np.ndarray
+        Array of shape (2, N) containing contour points in ACPC space.
+    start_point : np.ndarray
+        A point on the line, shape (2,).
+    direction : np.ndarray
+        Direction vector of the line, shape (2,).
+
+    Returns
+    -------
+    np.ndarray
+        Array of shape (M, 2) containing intersection points, sorted along the direction.
+    """
+    start_point = np.asarray(start_point, dtype=float)
+    direction = np.asarray(direction, dtype=float)
+
+    # Normalize direction
+    dir_norm = np.linalg.norm(direction)
+    if dir_norm < 1e-10:
+        return np.empty((0, 2))
+    direction = direction / dir_norm
+
+    n_points = contour.shape[1]
+    intersections = []
+
+    # Check intersection with each segment of the closed contour
+    for i in range(n_points):
+        seg_start = contour[:, i]
+        seg_end = contour[:, (i + 1) % n_points]  # Wrap around to close the contour
+
+        intersection = _line_segment_intersection(
+            start_point, direction, seg_start, seg_end
+        )
+        if intersection is not None:
+            intersections.append(intersection)
+
+    if not intersections:
+        return np.empty((0, 2))
+
+    points = np.array(intersections)
+
+    # Remove duplicate points (can occur at contour vertices)
+    if len(points) > 1:
+        # Project onto line direction and find unique points
+        projections = np.dot(points - start_point, direction)
+        # Sort and remove duplicates within tolerance
+        sorted_idx = np.argsort(projections)
+        points = points[sorted_idx]
+        projections = projections[sorted_idx]
+
+        # Keep points that are sufficiently far apart
+        mask = np.ones(len(points), dtype=bool)
+        for i in range(1, len(points)):
+            if abs(projections[i] - projections[i - 1]) < 1e-8:
+                mask[i] = False
+        points = points[mask]
+
+    return points
+
+
+def calculate_cc_index(cc_contour: np.ndarray, plot: bool = False) -> float:
+    """Calculate CC index based on three thickness measurements.
+
+    The AP line intersects the contour 4 times. The measurements are:
+    - Anterior thickness: distance between intersection points 1 and 2
+    - Posterior thickness: distance between intersection points 3 and 4
+    - Middle thickness: perpendicular line through midpoint of AP line
+
+    The CC index is: (anterior + posterior + middle) / AP_length
 
     Parameters
     ----------
     cc_contour : np.ndarray
         Array of shape (2, N) containing contour points in ACPC space.
+    plot : bool, optional
+        Whether to generate a debug plot. Default is False.
 
     Returns
     -------
     cc_index : float
         The CC index, which is the sum of thicknesses at three measurement points
         divided by AP length.
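+
+    Notes
+    -----
+    In symbols (a restatement of the summary above, added for clarity): with
+    anterior, posterior and middle thickness :math:`t_a`, :math:`t_p`,
+    :math:`t_m` and AP length :math:`L`, the returned index is
+    :math:`(t_a + t_p + t_m) / L`.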
""" - # Get anterior and posterior points + # Get anterior and posterior points (extremes along x-axis) anterior_idx = np.argmin(cc_contour[0]) # Leftmost point posterior_idx = np.argmax(cc_contour[0]) # Rightmost point - # Get the longest line (anterior to posterior) - ap_line = cc_contour[:, posterior_idx] - cc_contour[:, anterior_idx] - ap_length = np.linalg.norm(ap_line) - ap_unit = np.array([-ap_line[1], ap_line[0]]) / ap_length - - # Get midpoint of AP line - midpoint = cc_contour[:, anterior_idx] + (ap_line / 2) - - # Get perpendicular direction - - # Get intersection points with contour for each measurement line - def get_intersections(start_point: np.ndarray, direction: np.ndarray) -> np.ndarray: - """Find intersection points between a line and the contour. - - Parameters - ---------- - start_point : np.ndarray - Starting point of the line, shape (2,). - direction : np.ndarray - Direction vector of the line, shape (2,). - - Returns - ------- - np.ndarray - Array of shape (N, 2) containing intersection points. - """ - # Get all points above and below the line - points = cc_contour.T - start_point[None, :] - dots = np.dot(points, direction) - signs = np.sign(dots) - sign_changes = np.where(np.diff(signs))[0] - - # Linear interpolation between points - t = -dots[sign_changes] / (dots[sign_changes + 1] - dots[sign_changes]) - return cc_contour[:, sign_changes] + t * (cc_contour[:, sign_changes + 1] - cc_contour[:, sign_changes]) - - # Get three measurements - most_anterior_pt = cc_contour[:, anterior_idx] - perpendicular_unit = np.array([-ap_unit[1], ap_unit[0]]) - - anterior_intersections = get_intersections(most_anterior_pt - 10 * perpendicular_unit, ap_unit) - - # sort by x - anterior_intersections = anterior_intersections[np.argsort(anterior_intersections[:, 0])] - - middle_ints = get_intersections(midpoint, perpendicular_unit) - - if len(middle_ints) != 2: - logger.warning( - f"The perpendicular line should intersect the contour twice, " - f"but it intersects {len(middle_ints)} times" + anterior_pt = cc_contour[:, anterior_idx] + posterior_pt = cc_contour[:, posterior_idx] + + # AP line vector and properties + ap_vector = posterior_pt - anterior_pt + ap_length = np.linalg.norm(ap_vector) + ap_unit = ap_vector / ap_length + + # Perpendicular direction (90 degrees rotated) + perp_unit = np.array([-ap_unit[1], ap_unit[0]]) + + # Find where AP line intersects the contour (should be 4 points) + ap_intersections = get_intersections( + contour=cc_contour, start_point=anterior_pt, direction=ap_unit + ) + + if len(ap_intersections) != 4: + logger.error( + f"AP line should intersect contour exactly 4 times, " + f"but found {len(ap_intersections)} intersections" ) + return 0.0 + + # Measurement 1: anterior thickness (between intersection points 1 and 2) + anterior_thickness = np.linalg.norm(ap_intersections[0] - ap_intersections[1]) + + # Measurement 2: posterior thickness (between intersection points 3 and 4) + posterior_thickness = np.linalg.norm(ap_intersections[2] - ap_intersections[3]) - # plt.close() + # AP distance is between outermost intersection points (1 and 4) + ap_distance = np.linalg.norm(ap_intersections[0] - ap_intersections[3]) - # calculate index - ap_distance = np.linalg.norm(anterior_intersections[0] - anterior_intersections[-1]) - anterior_distance = np.linalg.norm(anterior_intersections[0] - anterior_intersections[1]) - posterior_distance = np.linalg.norm(anterior_intersections[-1] - anterior_intersections[-2]) - top_distance = np.linalg.norm(middle_ints[0] - 
middle_ints[1]) + # Midpoint of AP line (between points 1 and 4, or between anterior and posterior extremes) + midpoint = (ap_intersections[0] + ap_intersections[3]) / 2 - cc_index = (anterior_distance + posterior_distance + top_distance) / ap_distance + # Measurement 3: perpendicular line through midpoint + middle_intersections = get_intersections( + contour=cc_contour, start_point=midpoint, direction=perp_unit + ) + middle_thickness = np.linalg.norm(middle_intersections[0] - middle_intersections[-1]) + + # Calculate CC index + cc_index = (anterior_thickness + posterior_thickness + middle_thickness) / ap_distance + + if plot: + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(figsize=(8, 6)) + plot_cc_index_calculation( + ax, + cc_contour, + anterior_idx, + posterior_idx, + ap_intersections, + middle_intersections, + midpoint, + ) + ax.legend() + plt.show() return cc_index + + +def plot_cc_index_calculation( + ax, + cc_contour: np.ndarray, + anterior_idx: int, + posterior_idx: int, + ap_intersections: np.ndarray, + middle_intersections: np.ndarray, + midpoint: np.ndarray, +) -> None: + """Plot the CC index measurements. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axes to plot on. + cc_contour : np.ndarray + Array of shape (2, N) containing contour points in ACPC space. + anterior_idx : int + Index of the anterior point on the contour. + posterior_idx : int + Index of the posterior point on the contour. + ap_intersections : np.ndarray + Array of shape (4, 2) containing the 4 intersection points of the AP line with the contour. + middle_intersections : np.ndarray + Array of shape (2, 2) containing middle perpendicular intersection points. + midpoint : np.ndarray + Array of shape (2,) containing the midpoint of the AP line. 
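+
+    Notes
+    -----
+    Debug helper; typically reached via ``calculate_cc_index(..., plot=True)``
+    as shown above rather than called directly.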
+ """ + from matplotlib.patches import PathPatch + from matplotlib.path import Path + + # Plot the CC contour (closed) + ax.plot(cc_contour[0], cc_contour[1], "k-", linewidth=1) + ax.plot( + [cc_contour[0, -1], cc_contour[0, 0]], + [cc_contour[1, -1], cc_contour[1, 0]], + "k-", + linewidth=1, + ) + + # Plot AP line through all 4 intersection points + ax.plot( + [ap_intersections[0, 0], ap_intersections[3, 0]], + [ap_intersections[0, 1], ap_intersections[3, 1]], + "r--", + linewidth=1, + label="AP line", + ) + + # Mark all 4 intersection points + for i, pt in enumerate(ap_intersections): + ax.scatter([pt[0]], [pt[1]], s=40, zorder=5) + ax.annotate(f"{i+1}", (pt[0], pt[1]), textcoords="offset points", + xytext=(5, 5), fontsize=10) + + # Plot anterior thickness (points 1-2) + ax.plot( + [ap_intersections[0, 0], ap_intersections[1, 0]], + [ap_intersections[0, 1], ap_intersections[1, 1]], + "b-", + linewidth=3, + label="Anterior thickness (1-2)", + ) + + # Plot posterior thickness (points 3-4) + ax.plot( + [ap_intersections[2, 0], ap_intersections[3, 0]], + [ap_intersections[2, 1], ap_intersections[3, 1]], + "c-", + linewidth=3, + label="Posterior thickness (3-4)", + ) + + # Plot middle thickness measurement (perpendicular) + ax.plot( + [middle_intersections[0, 0], middle_intersections[-1, 0]], + [middle_intersections[0, 1], middle_intersections[-1, 1]], + "g-", + linewidth=3, + label="Middle thickness", + ) + + # Mark midpoint + ax.scatter([midpoint[0]], [midpoint[1]], color="red", s=50, zorder=5, + marker="x", label="Midpoint") + + ax.set_aspect("equal") + + # Fill the contour with gray + contour_path = Path(cc_contour.T) + patch = PathPatch(contour_path, facecolor="gray", alpha=0.2, edgecolor=None) + ax.add_patch(patch) + + ax.invert_xaxis() + ax.axis("off") diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index 3b292c52..f3196af2 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -148,6 +148,7 @@ def recon_cc_surf_measures_multi( num_slices = 1 # Process only the middle slice slice_iterator = [segmentation.shape[0] // 2] + start_slice = segmentation.shape[0] // 2 elif slice_selection == "all": num_slices = segmentation.shape[0] start_slice = 0 @@ -156,6 +157,7 @@ def recon_cc_surf_measures_multi( else: # specific slice number num_slices = 1 slice_iterator = [int(slice_selection)] + start_slice = int(slice_selection) it_affine = map(partial(create_slice_affine, upright_affine, fsaverage_middle=FSAVERAGE_MIDDLE), slice_iterator) @@ -168,7 +170,7 @@ def _yield_iterator(): for _slice_idx in slice_iterator: try: yield _slice_idx, *next(iterator) - except ValueError as e: + except Exception as e: logger.error(f"Slice {_slice_idx} failed with error: {e}") logger.exception(e) except StopIteration: @@ -179,7 +181,7 @@ def _yield_iterator(): # insert progress = f" ({i+1} of {num_slices})" if num_slices > 1 else "" logger.info(f"Calculating CC measurements for slice {slice_idx+1}{progress}") - cc_mesh.add_contour(slice_idx, *contour_with_thickness, start_end_idx=endpoint_idxs) + cc_mesh.add_contour(start_slice-slice_idx, *contour_with_thickness, start_end_idx=endpoint_idxs) if result is None: continue From 246fc77fbc16380ba062498ff9723e95a4750b6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Thu, 4 Dec 2025 12:10:04 +0100 Subject: [PATCH 40/68] Fix AC-PC localization update plotting in localization of ac pc --- CorpusCallosum/localization/inference.py | 29 +++++++++++++----------- 1 file 
changed, 16 insertions(+), 13 deletions(-) diff --git a/CorpusCallosum/localization/inference.py b/CorpusCallosum/localization/inference.py index 0d6e5716..78f7e595 100644 --- a/CorpusCallosum/localization/inference.py +++ b/CorpusCallosum/localization/inference.py @@ -183,8 +183,7 @@ def run_inference( inputs = transformed_original.to(device) inputs = inputs.transpose(0, 1) - batch_size, channels, height, width = inputs.shape - inputs = inputs.unfold(0, 3, 1).swapdims(0, 1).reshape(-1, 3*channels, height, width) + inputs = inputs.unfold(0, 3, 1).transpose(1, -1)[..., 0] # Run inference with torch.no_grad(): @@ -199,6 +198,7 @@ def run_inference_on_slice( model: DenseNet, image_slice: np.ndarray, center_pt: np.ndarray, + num_iterations: int = 2, debug_output: str | None = None, ) -> tuple[npt.NDArray[float], npt.NDArray[float]]: """Run inference on a single slice to detect AC and PC points. @@ -211,23 +211,26 @@ def run_inference_on_slice( 3D image mid-slices to run inference on in RAS. center_pt : np.ndarray Initial center point estimate for cropping. + num_iterations : int, default=2 + Number of refinement iterations to run. debug_output : str, optional Path to save debug visualization, by default None. Returns ------- ac_coords : np.ndarray - Detected AC coordinates with shape (2,) containing its [y,x] positions. + Detected AC voxel coordinates with shape (2,) containing its [y,x] positions. pc_coords : np.ndarray - Detected PC coordinates with shape (2,) containing its [y,x] positions. + Detected PC voxel coordinates with shape (2,) containing its [y,x] positions. """ # Run inference - pc_coords, ac_coords, *_ = run_inference(model, image_slice, center_pt) - center_pt = np.mean(np.concatenate([ac_coords, pc_coords], axis=0), axis=0) - pc_coords, ac_coords, _, (crop_left, crop_top) = run_inference(model, image_slice, center_pt) - pc_coords = np.mean(pc_coords, axis=0) - ac_coords = np.mean(ac_coords, axis=0) + for i in range(num_iterations): + pc_coords, ac_coords, _, (crop_left, crop_top) = run_inference(model, image_slice, center_pt) + center_pt = np.mean(np.stack([ac_coords, pc_coords], axis=0), axis=(0, 1)) + # average ac and pc coords across sagittal slices + pc_coords = np.mean(pc_coords, axis=0, keepdims=True) + ac_coords = np.mean(ac_coords, axis=0, keepdims=True) if debug_output is not None: import matplotlib.pyplot as plt @@ -235,12 +238,12 @@ def run_inference_on_slice( fig, ax = plt.subplots(1, 1, figsize=(10, 8)) ax.imshow(image_slice[image_slice.shape[0]//2, :, :], cmap='gray') # Plot points on all views - ax.scatter(pc_coords[1], pc_coords[0], c='r', marker='x', label='PC') - ax.scatter(ac_coords[1], ac_coords[0], c='b', marker='x', label='AC') + ax.scatter(pc_coords[:, 1], pc_coords[:, 0], c='r', marker='x', label='PC') + ax.scatter(ac_coords[:, 1], ac_coords[:, 0], c='b', marker='x', label='AC') # make a box where the crop is - ax.add_patch(Rectangle((crop_top, crop_left), 64, 64, fill=False, color='r', linewidth=2)) + ax.add_patch(Rectangle((crop_top, crop_left), 64, 64, fill=False, color='r', linewidth=2)) plt.savefig(debug_output, bbox_inches='tight') plt.close() - return ac_coords, pc_coords + return ac_coords[0], pc_coords[0] From 0e0f5cffe4dc4f7bbb26a7f2111715f0efec2f43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Wed, 3 Dec 2025 19:09:27 +0100 Subject: [PATCH 41/68] Rework the generation of transformation matrices and make them more unified and clearer also fixing broken transformation matrices Renaming of functions and variables for 
better description of function prepare processing of non-isotropic images Use lapy 1.4 function to save surface in vox2ras_tkr coordinates Rename centroid_registration to register_centroids_to_fsavg Fix vox_size arguments in multiple places Use FastSurfer's thread_executor to perform parallelization Fix generation of aseg_auto_statsfile in run_fastsurfer Add a commented solution to vectorize subsegment_contour.subsegment_midline_orthogonal Improve Error messages Update/fix docstrings improve type annotations, including additional overload signatures Add a CCMeasurementsDict TypedDict for CC Measures Reformatting Various fixes to CC generation (messages to be edited) --- CorpusCallosum/data/fsaverage_cc_template.py | 25 +- .../data/generate_fsaverage_centroids.py | 4 +- CorpusCallosum/data/read_write.py | 15 +- CorpusCallosum/fastsurfer_cc.py | 265 ++++++++++------- CorpusCallosum/localization/inference.py | 8 +- CorpusCallosum/segmentation/inference.py | 18 +- .../segmentation_postprocessing.py | 64 ++-- CorpusCallosum/shape/endpoint_heuristic.py | 117 ++++---- CorpusCallosum/shape/mesh.py | 62 ++-- CorpusCallosum/shape/postprocessing.py | 279 ++++++++++-------- CorpusCallosum/shape/subsegment_contour.py | 52 +++- CorpusCallosum/shape/thickness.py | 190 ++++++------ CorpusCallosum/transforms/segmentation.py | 67 +++-- CorpusCallosum/utils/mapping_helpers.py | 38 ++- CorpusCallosum/utils/visualization.py | 52 ++-- run_fastsurfer.sh | 2 +- 16 files changed, 728 insertions(+), 530 deletions(-) diff --git a/CorpusCallosum/data/fsaverage_cc_template.py b/CorpusCallosum/data/fsaverage_cc_template.py index a49e5c61..e307dedd 100644 --- a/CorpusCallosum/data/fsaverage_cc_template.py +++ b/CorpusCallosum/data/fsaverage_cc_template.py @@ -120,18 +120,19 @@ def load_fsaverage_cc_template() -> tuple[ # Use the smoothed mask for further processing cc_mask = cc_mask_smoothed.astype(int) * 192 - (_, contour_with_thickness, anterior_endpoint_idx, - posterior_endpoint_idx) = recon_cc_surf_measure(segmentation=cc_mask[None], - slice_idx=0, - ac_coords=AC, - pc_coords=PC, - affine=fsaverage_seg.affine, - num_thickness_points=100, - subdivisions=[1/6, 1/2, 2/3, 3/4], - subdivision_method="shape", - contour_smoothing=5, - vox_size=1) - outside_contour = contour_with_thickness[0].T + _, contour_with_thickness, (anterior_endpoint_idx, posterior_endpoint_idx) = recon_cc_surf_measure( + segmentation=cc_mask[None], + slice_idx=0, + ac_coords=AC, + pc_coords=PC, + affine=fsaverage_seg.affine, + num_thickness_points=100, + subdivisions=[1/6, 1/2, 2/3, 3/4], + subdivision_method="shape", + contour_smoothing=5, + vox_size=(1., 1., 1.), # fsaverage is in 1mm isotropic + ) + outside_contour = contour_with_thickness[:2].T # make sure the CC stays in shape despite smoothing by moving endpoints outwards diff --git a/CorpusCallosum/data/generate_fsaverage_centroids.py b/CorpusCallosum/data/generate_fsaverage_centroids.py index 9b7abf74..4dd874ac 100644 --- a/CorpusCallosum/data/generate_fsaverage_centroids.py +++ b/CorpusCallosum/data/generate_fsaverage_centroids.py @@ -26,7 +26,7 @@ import nibabel as nib import numpy as np -from read_write import convert_numpy_to_json_serializable, get_centroids_from_nib +from read_write import convert_numpy_to_json_serializable, calc_ras_centroids_from_seg import FastSurferCNN.utils.logging as logging @@ -83,7 +83,7 @@ def main() -> None: # Extract centroids logger.info("Extracting centroids from fsaverage...") - centroids_dst = get_centroids_from_nib(fsaverage_nib) + centroids_dst = 
calc_ras_centroids_from_seg(fsaverage_nib) logger.info(f"Found {len(centroids_dst)} anatomical structures with centroids") diff --git a/CorpusCallosum/data/read_write.py b/CorpusCallosum/data/read_write.py index abc5bf01..a0621263 100644 --- a/CorpusCallosum/data/read_write.py +++ b/CorpusCallosum/data/read_write.py @@ -21,6 +21,7 @@ from numpy import typing as npt import FastSurferCNN.utils.logging as logging +from FastSurferCNN.utils import AffineMatrix4x4 from FastSurferCNN.utils.parallel import thread_executor @@ -33,9 +34,9 @@ class FSAverageHeader(TypedDict): logger = logging.get_logger(__name__) -def get_centroids_from_nib(seg_img: nib.analyze.SpatialImage, label_ids: list[int] | None = None) \ +def calc_ras_centroids_from_seg(seg_img: nib.analyze.SpatialImage, label_ids: list[int] | None = None) \ -> dict[int, np.ndarray | None]: - """Get centroids of segmentation labels in RAS coordinates. + """Get centroids of segmentation labels in RAS coordinates, accepts any affine/data layout. Parameters ---------- @@ -51,7 +52,7 @@ def get_centroids_from_nib(seg_img: nib.analyze.SpatialImage, label_ids: list[in """ # Get segmentation data and affine seg_data: npt.NDArray[np.integer] = np.asarray(seg_img.dataobj) - vox2ras: npt.NDArray[float] = seg_img.affine + vox2ras: AffineMatrix4x4 = seg_img.affine # Get unique labels if label_ids is None: @@ -82,7 +83,7 @@ def convert_numpy_to_json_serializable(obj: object) -> object: Parameters ---------- - obj : object + obj : dict, list, array, number, serializable Object to convert to JSON serializable type. Returns @@ -154,7 +155,7 @@ def load_fsaverage_affine(affine_path: str | Path) -> npt.NDArray[float]: return affine_matrix -def load_fsaverage_data(data_path: str | Path) -> tuple[npt.NDArray[float], FSAverageHeader, npt.NDArray[float]]: +def load_fsaverage_data(data_path: str | Path) -> tuple[AffineMatrix4x4, FSAverageHeader, AffineMatrix4x4]: """Load fsaverage affine matrix and header fields from static JSON file. Parameters @@ -164,7 +165,7 @@ def load_fsaverage_data(data_path: str | Path) -> tuple[npt.NDArray[float], FSAv Returns ------- - affine_matrix : np.ndarray + affine_matrix : AffineMatrix4x4 4x4 affine transformation matrix. header_fields : dict Header fields needed for LTA: @@ -176,7 +177,7 @@ def load_fsaverage_data(data_path: str | Path) -> tuple[npt.NDArray[float], FSAv 3x3 direction cosines matrix. - Pxyz_c : np.ndarray RAS center coordinates [x,y,z]. - vox2ras_tkr : np.ndarray + vox2ras_tkr : AffineMatrix4x4 Voxel to RAS tkr-space transformation matrix. 
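For readers following the rename: `calc_ras_centroids_from_seg` reduces each label to its mean voxel index and pushes that index through the image's vox2ras affine. A minimal sketch of that computation (an illustrative helper, not the FastSurfer implementation):

```python
import numpy as np

def ras_centroid_sketch(seg: np.ndarray, affine: np.ndarray, label: int) -> np.ndarray | None:
    """Mean voxel index of `label`, mapped to RAS via the 4x4 `affine`."""
    idx = np.argwhere(seg == label)      # (N, 3) voxel indices of the label
    if idx.size == 0:
        return None                      # label absent from the segmentation
    centroid_vox = idx.mean(axis=0)      # (3,) voxel-space centroid
    hom = np.append(centroid_vox, 1.0)   # homogeneous coordinates
    return (affine @ hom)[:3]            # RAS coordinates
```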
Raises diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index 0bdf17c5..16ea8fa2 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -16,6 +16,7 @@ import argparse import json from collections.abc import Iterable +from functools import partial from pathlib import Path from time import perf_counter_ns from typing import Literal, TypeVar, cast @@ -25,6 +26,7 @@ import torch from monai.networks.nets import DenseNet from numpy import typing as npt +from scipy.ndimage import affine_transform from CorpusCallosum.data.constants import ( CC_LABEL, @@ -37,8 +39,8 @@ ) from CorpusCallosum.data.read_write import ( FSAverageHeader, + calc_ras_centroids_from_seg, convert_numpy_to_json_serializable, - get_centroids_from_nib, load_fsaverage_centroids, load_fsaverage_data, ) @@ -49,6 +51,7 @@ SliceSelection, SubdivisionMethod, check_area_changes, + create_sag_slice_vox2vox, make_subdivision_mask, recon_cc_surf_measures_multi, ) @@ -56,18 +59,17 @@ apply_transform_to_pt, apply_transform_to_volume, calc_mapping_to_standard_space, - interpolate_midplane, map_softlabels_to_orig, ) -from FastSurferCNN.data_loader.conform import is_conform +from FastSurferCNN.data_loader.conform import conform, is_conform from FastSurferCNN.segstats import HelpFormatter -from FastSurferCNN.utils import logging +from FastSurferCNN.utils import AffineMatrix4x4, logging from FastSurferCNN.utils.arg_types import path_or_none from FastSurferCNN.utils.common import SubjectDirectory, find_device +from FastSurferCNN.utils.lta import write_lta from FastSurferCNN.utils.parallel import shutdown_executors, thread_executor from FastSurferCNN.utils.parser_defaults import modify_argument from recon_surf.align_points import find_rigid -from recon_surf.lta import write_lta logger = logging.get_logger(__name__) _TPathLike = TypeVar("_TPathLike", str, Path, Literal[None]) @@ -353,9 +355,8 @@ def options_parse() -> argparse.Namespace: return args -def centroid_registration(aseg_nib: nib.analyze.SpatialImage) -> tuple[ - npt.NDArray[float], npt.NDArray[float], npt.NDArray[float], FSAverageHeader, npt.NDArray[float] -]: +def register_centroids_to_fsavg(aseg_nib: nib.analyze.SpatialImage) \ + -> tuple[AffineMatrix4x4, AffineMatrix4x4, AffineMatrix4x4, FSAverageHeader, AffineMatrix4x4]: """Perform centroid-based registration between subject and fsaverage space. Computes a rigid transformation between the subject's segmentation and fsaverage space @@ -368,15 +369,15 @@ def centroid_registration(aseg_nib: nib.analyze.SpatialImage) -> tuple[ Returns ------- - orig_fsaverage_vox2vox : np.ndarray + aseg2fsaverage_vox2vox : AffineMatrix4x4 Transformation matrix from original to fsaverage voxel space. - orig_fsaverage_ras2ras : np.ndarray + aseg2fsaverage_ras2ras : AffineMatrix4x4 Transformation matrix from original to fsaverage RAS space. - fsaverage_hires_affine : np.ndarray + fsaverage_hires_vox2ras : AffineMatrix4x4 High-resolution fsaverage affine matrix. fsaverage_header : FSAverageHeader FSAverage header fields for LTA writing. - vox2ras_tkr : np.ndarray + fsaverage_vox2ras_tkr : AffineMatrix4x4 Voxel to RAS tkr-space transformation matrix. 
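The registration built on these centroids is a rigid point-cloud fit; `find_rigid` from `recon_surf.align_points` presumably follows the standard Kabsch construction sketched below (assumption: rotation plus translation, no scaling; the actual implementation may differ in details):

```python
import numpy as np

def rigid_fit_sketch(p_mov: np.ndarray, p_dst: np.ndarray) -> np.ndarray:
    """Kabsch-style rigid fit for point sets of shape (N, 3).

    Returns a 4x4 matrix T with T @ [p_mov, 1] ~= [p_dst, 1].
    """
    c_mov, c_dst = p_mov.mean(axis=0), p_dst.mean(axis=0)
    h = (p_mov - c_mov).T @ (p_dst - c_dst)   # 3x3 cross-covariance
    u, _, vt = np.linalg.svd(h)
    d = np.sign(np.linalg.det(vt.T @ u.T))    # guard against reflections
    rot = vt.T @ np.diag([1.0, 1.0, d]) @ u.T
    t = np.eye(4)
    t[:3, :3] = rot
    t[:3, 3] = c_dst - rot @ c_mov
    return t
```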
Notes @@ -388,36 +389,43 @@ def centroid_registration(aseg_nib: nib.analyze.SpatialImage) -> tuple[ logger.info("Starting centroid registration") # Load pre-computed fsaverage centroids and data from static files - centroids_dst = load_fsaverage_centroids(FSAVERAGE_CENTROIDS_PATH) fsaverage_data_future = thread_executor().submit(load_fsaverage_data, FSAVERAGE_DATA_PATH) + ras_centroids_dst = load_fsaverage_centroids(FSAVERAGE_CENTROIDS_PATH) - centroids_mov = get_centroids_from_nib(aseg_nib, label_ids=list(centroids_dst.keys())) + ras_centroids_mov = calc_ras_centroids_from_seg(aseg_nib, label_ids=list(ras_centroids_dst.keys())) # get the set of joint labels - joint_centroid_labels = [lbl for lbl, v in centroids_mov.items() if v is not None] + joint_centroid_labels = [lbl for lbl, v in ras_centroids_mov.items() if v is not None] - centroids_mov = np.array([centroids_mov[lbl] for lbl in joint_centroid_labels]).T - centroids_dst = np.array([centroids_dst[lbl] for lbl in joint_centroid_labels]).T + ras_centroids_mov = np.array([ras_centroids_mov[lbl] for lbl in joint_centroid_labels]).T + ras_centroids_dst = np.array([ras_centroids_dst[lbl] for lbl in joint_centroid_labels]).T - orig_fsaverage_ras2ras: npt.NDArray[float] = find_rigid(p_mov=centroids_mov.T, p_dst=centroids_dst.T) + aseg2fsaverage_ras2ras: AffineMatrix4x4 = find_rigid(p_mov=ras_centroids_mov.T, p_dst=ras_centroids_dst.T) # make affine that increases resolution to orig resolution - resolution_trans: npt.NDArray[float] = np.diagflat(list(aseg_nib.header.get_zooms()[:3]) + [1]).astype(float) + aseg_zooms = list(nib.as_closest_canonical(aseg_nib).header.get_zooms()[:3]) + resolution_trans: AffineMatrix4x4 = np.diagflat([aseg_zooms[0], aseg_zooms[2], aseg_zooms[1], 1]).astype(float) - fsaverage_affine, fsaverage_header, vox2ras_tkr = fsaverage_data_future.result() - _highres_fsaverage: npt.NDArray[float] = np.linalg.inv(resolution_trans @ fsaverage_affine) - orig_fsaverage_vox2vox: npt.NDArray[float] = _highres_fsaverage @ orig_fsaverage_ras2ras @ aseg_nib.affine - fsaverage_hires_affine: npt.NDArray[float] = resolution_trans @ fsaverage_affine + fsaverage_vox2ras, fsavg_header, vox2ras_tkr = fsaverage_data_future.result() + fsavg_header["delta"] = np.asarray([aseg_zooms[0], aseg_zooms[2], aseg_zooms[1]]) # vox sizes in lia + # fsavg_hires_vox2ras translation should be 128 always (independent of resolution) + fsavg_hires_vox2ras: AffineMatrix4x4 = np.concatenate( + [(resolution_trans @ fsaverage_vox2ras)[:, :3], fsaverage_vox2ras[:, 3:4]], + axis=1, + ) + fsavg_header["dims"] = np.ceil(fsavg_header["dims"] @ np.linalg.inv(resolution_trans[:3, :3])).astype(int).tolist() + + aseg2fsavg_vox2vox: AffineMatrix4x4 = np.linalg.inv(fsavg_hires_vox2ras) @ aseg2fsaverage_ras2ras @ aseg_nib.affine logger.info("Centroid registration successful!") - return orig_fsaverage_vox2vox, orig_fsaverage_ras2ras, fsaverage_hires_affine, fsaverage_header, vox2ras_tkr + return aseg2fsavg_vox2vox, aseg2fsaverage_ras2ras, fsavg_hires_vox2ras, fsavg_header, vox2ras_tkr def localize_ac_pc( - midslices: np.ndarray, + orig_data: np.ndarray, aseg_nib: nib.analyze.SpatialImage, - orig_fsaverage_vox2vox: npt.NDArray[float], + orig2midslice_vox2vox: AffineMatrix4x4, model_localization: DenseNet, - num_slices_to_analyze: int + resample_shape: tuple[int, int, int], ) -> tuple[npt.NDArray[float], npt.NDArray[float]]: """Localize anterior and posterior commissure points in the brain. 
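The `aseg2fsavg_vox2vox` composition above follows the usual right-to-left reading: moving voxel to RAS, rigid RAS to RAS, then RAS back to target voxel. In isolation (identity placeholders stand in for the three 4x4 matrices):

```python
import numpy as np

mov_vox2ras = np.eye(4)   # subject voxel -> subject RAS (aseg_nib.affine)
ras2ras = np.eye(4)       # subject RAS -> fsaverage RAS (rigid fit)
dst_vox2ras = np.eye(4)   # fsaverage voxel -> fsaverage RAS

# voxel -> voxel, read right-to-left: vox -> RAS -> RAS -> vox
mov2dst_vox2vox = np.linalg.inv(dst_vox2ras) @ ras2ras @ mov_vox2ras
```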
@@ -426,15 +434,15 @@ def localize_ac_pc(
 
     Parameters
     ----------
-    midslices : np.ndarray
-        Array of mid-sagittal slices.
+    orig_data : np.ndarray
+        Array of intensity data.
     aseg_nib : nibabel.analyze.SpatialImage
         Subject's segmentation image in native subject space.
-    orig_fsaverage_vox2vox : np.ndarray
+    orig2midslice_vox2vox : np.ndarray
         Transformation matrix from subject/native space to fsaverage space (in lia).
     model_localization : DenseNet
         Trained model for AC-PC detection.
-    num_slices_to_analyze : int
-        Number of slices to process.
+    resample_shape : 3-tuple of ints
+        Shape of the resampled mid-slab: (number of slices, height, width).
 
     Returns
     -------
     ac_coords : np.ndarray
         Coordinates of the anterior commissure.
     pc_coords : np.ndarray
         Coordinates of the posterior commissure.
     """
-
-    # get center of third ventricle from aseg and map to fsaverage space
+    num_slices_to_analyze = resample_shape[0]
+    resample_shape = (num_slices_to_analyze + 2,) + resample_shape[1:]  # 2 for context slices
+    _midslices_fut = thread_executor().submit(
+        affine_transform,
+        orig_data,
+        np.linalg.inv(orig2midslice_vox2vox),  # inverse is required for affine_transform
+        output_shape=resample_shape,
+        order=2,  # unclear, why this is not order=3
+        mode="constant",
+        cval=0,
+        prefilter=True,  # unclear, why we are using a smoothing filter here
+    )
+
+    # get center of third ventricle from aseg and map to fsaverage space (voxel coordinates)
     third_ventricle_mask = np.asarray(aseg_nib.dataobj) == THIRD_VENTRICLE_LABEL
     third_ventricle_center = np.argwhere(third_ventricle_mask).mean(axis=0)
-    third_ventricle_center_vox = apply_transform_to_pt(third_ventricle_center, orig_fsaverage_vox2vox, inv=False)
+    third_ventricle_center_vox = apply_transform_to_pt(third_ventricle_center, orig2midslice_vox2vox, inv=False)
 
-    # get 5 mm of slices output with 3 slices per inference
-    midslices_start = midslices.shape[0] // 2 - num_slices_to_analyze // 2 - 1
-    middle_slices_localization = midslices[midslices_start:midslices_start + num_slices_to_analyze + 3]
+    # get 5 mm of slices with 3 slices per inference (cropping num_slices_to_analyze + 2 slices around the center)
     ac_coords, pc_coords = localization_inference.run_inference_on_slice(
-        model_localization, middle_slices_localization, third_ventricle_center_vox[1:],
+        model_localization, _midslices_fut.result(), third_ventricle_center_vox[1:],
     )
 
     return ac_coords, pc_coords
 
@@ -466,7 +484,6 @@ def segment_cc(
     pc_coords: npt.NDArray[float],
     aseg_nib: "nib.Nifti1Image",
     model_segmentation: "torch.nn.Module",
-    slices_to_analyze: int,
 ) -> tuple[npt.NDArray[bool], npt.NDArray[float]]:
     """Segment the corpus callosum using a trained model.
 
@@ -485,8 +502,6 @@
         Subject's cc_seg_labels image.
     model_segmentation : torch.nn.Module
         Trained model for CC cc_seg_labels.
-    slices_to_analyze : int
-        Number of slices to process.
 
     Returns
     -------
     cc_seg_labels : np.ndarray
         Binary cc_seg_labels mask.
     cc_softlabels : np.ndarray
         Soft cc_seg_labels probabilities.
""" - # get 5 mm of slices output with 9 slices per inference - midslices_start = midslices.shape[0] // 2 - slices_to_analyze // 2 - 4 - middle_slices_slab = midslices[midslices_start:midslices_start + slices_to_analyze + 9] pre_clean_segmentation, inputs, cc_softlabels = segmentation_inference.run_inference_on_slice( model_segmentation, - middle_slices_slab, + midslices, ac_center=ac_coords, pc_center=pc_coords, - voxel_size=aseg_nib.header.get_zooms()[0], + voxel_size=nib.as_closest_canonical(aseg_nib).header.get_zooms()[2:0:-1], # convert from RAS to LIA ) cc_seg_labels, cc_volume_mask = segmentation_postprocessing.clean_cc_segmentation(pre_clean_segmentation) @@ -674,78 +686,108 @@ def main( #### setup variables io_futures = [] + _aseg_fut = thread_executor().submit(nib.load, sd.filename_by_attribute("aseg_name")) orig = cast(nib.analyze.SpatialImage, nib.load(sd.conf_name)) - # 5 mm around the midplane (making sure to get rl by as_closest_canonical) - vox_size = nib.as_closest_canonical(orig).header.get_zooms()[0] - slices_to_analyze = int(np.ceil(5 / vox_size)) + # check that the image is conformed, i.e. isotropic 1mm voxels, 256^3 size, LIA orientation + if not is_conform(orig, vox_size=None, img_size=None, orientation=None): + logger.info("Internally conforming orig to soft-LIA.") + orig = conform(orig, vox_size=None, img_size=None, orientation=None) + + # 5 mm around the midplane (guaranteed to be aligned RAS by as_closest_canonical) + vox_size_ras: tuple[float, float, float] = nib.as_closest_canonical(orig).header.get_zooms() + vox_size = vox_size_ras[0], vox_size_ras[2], vox_size_ras[1] # convert from RAS to LIA + slices_to_analyze = int(np.ceil(5 / vox_size[0])) + # slices_to_analyze must be odd if slices_to_analyze % 2 == 0: slices_to_analyze += 1 logger.info( - f"Segmenting {slices_to_analyze} slices (5 mm width at {vox_size} mm resolution, " + f"Segmenting {slices_to_analyze} slices (5 mm width at {vox_size[0]} mm resolution, " "center around the mid-sagittal plane)" ) - if not is_conform(orig, vox_size='min', img_size=None): - if is_conform(orig, vox_size=None, img_size=None): - logger.warning("fastsurfer_cc currently requires isotropic images.") - logger.error("MRI is not conformed, please run conform.py or mri_convert to conform the image.") - sys.exit(1) - # load models device = find_device(device) logger.info(f"Using device: {device}") logger.info("Loading models") - model_localization = localization_inference.load_model(device=device) - model_segmentation = segmentation_inference.load_model(device=device) + _model_localization = thread_executor().submit(localization_inference.load_model, device=device) + _model_segmentation = thread_executor().submit(segmentation_inference.load_model, device=device) - aseg_nib = cast(nib.analyze.SpatialImage, nib.load(sd.filename_by_attribute("aseg_name"))) + aseg_img = cast(nib.analyze.SpatialImage, _aseg_fut.result()) + + if not np.allclose(aseg_img.affine, orig.affine): + logger.error("Input MRI and segmentation are not aligned! 
Please check your input files.") + sys.exit(1) logger.info("Performing centroid registration to fsaverage space") - orig2fsavg_vox2vox, orig2fsavg_ras2ras, fsavg_affine, fsavg_header, fsavg_vox2ras_tkr = centroid_registration( - aseg_nib + orig2fsavg_vox2vox, orig2fsavg_ras2ras, fsavg_vox2ras, fsavg_header, fsavg_vox2ras_tkr = ( + register_centroids_to_fsavg(aseg_img) ) - logger.info("Interpolating midplane slices") - # this is a fast interpolation to not block the main thread - midslices = interpolate_midplane(orig, orig2fsavg_vox2vox, slices_to_analyze) - # start saving upright volume + # start saving upright volume, this is the image in fsaverage space but not yet oriented via AC-PC if sd.has_attribute("upright_volume"): + # upright == fsaverage-aligned io_futures.append( thread_executor().submit( apply_transform_to_volume, orig, orig2fsavg_vox2vox, - fsavg_affine, + fsavg_vox2ras, output_path=sd.filename_by_attribute("upright_volume"), - output_size=np.array([256, 256, 256]), + output_size=fsavg_header["dims"], ) ) + # calculate affine for segmentation volume + affine_x_offset = partial(create_sag_slice_vox2vox, fsaverage_middle=FSAVERAGE_MIDDLE / vox_size[0]) + fsavg2midslab_in_vox2vox: AffineMatrix4x4 = affine_x_offset(slices_to_analyze // 2) + # first, midslice->fsaverage in vox2vox, then vox2ras in fsaverage space + midslab_vox2ras: AffineMatrix4x4 = fsavg_vox2ras @ np.linalg.inv(fsavg2midslab_in_vox2vox) + + # calculate vox2vox for input resampling volumes + def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4: + fsavg2midslab = affine_x_offset(slices_to_analyze // 2 + extra_slices // 2) + # first, orig->fsaverage in vox2vox, then fsaverage->midslab in vox2vox + return fsavg2midslab @ orig2fsavg_vox2vox + #### do localization and segmentation inference logger.info("Starting AC/PC localization") + target_shape = (slices_to_analyze, fsavg_header["dims"][1], fsavg_header["dims"][2]) ac_coords, pc_coords = localize_ac_pc( - midslices, aseg_nib, orig2fsavg_vox2vox, model_localization, slices_to_analyze, + np.asarray(orig.dataobj), + aseg_img, + _orig2midslab_vox2vox(extra_slices=2), + _model_localization.result(), + target_shape, ) logger.info("Starting corpus callosum segmentation") + target_shape = (slices_to_analyze + 8, fsavg_header["dims"][1], fsavg_header["dims"][2]) # 8 for context slices + midslices = affine_transform( + np.asarray(orig.dataobj), + np.linalg.inv(_orig2midslab_vox2vox(extra_slices=8)), # inverse is required for affine_transform + output_shape=target_shape, + order=2, # @ClePol unclear, why this is not order=3 + mode="constant", + cval=0, + prefilter=True, # unclear, why we are using a smoothing filter here + ) cc_fn_seg_labels, cc_fn_softlabels = segment_cc( - midslices, ac_coords, pc_coords, aseg_nib, model_segmentation, slices_to_analyze, + midslices, + ac_coords, + pc_coords, + aseg_img, + _model_segmentation.result(), ) - # calculate affine for segmentation volume - orig_to_seg = np.eye(4) - orig_to_seg[0, 3] = -FSAVERAGE_MIDDLE + slices_to_analyze // 2 - seg_affine = fsavg_affine @ np.linalg.inv(orig_to_seg) - - # save softlabels + # save segmentation softlabels for i, (attr, name) in enumerate((("background",) * 2, ("cc", "Corpus Callosum"), ("fn", "Fornix"))): if sd.has_attribute(f"cc_softlabels_{attr}"): logger.info(f"Saving {name} softlabels to {sd.filename_by_attribute(f'cc_softlabels_{attr}')}") io_futures.append(thread_executor().submit( nib.save, - nib.MGHImage(cc_fn_softlabels[..., i], seg_affine, orig.header), + 
nib.MGHImage(cc_fn_softlabels[..., i], midslab_vox2ras, orig.header),
             sd.filename_by_attribute(f"cc_softlabels_{attr}"),
         ))
 
@@ -756,7 +798,7 @@ def main(
     slice_results, slice_io_futures = recon_cc_surf_measures_multi(
         segmentation=cc_fn_seg_labels,
         slice_selection=slice_selection,
-        upright_affine=fsavg_affine,
+        fsavg_vox2ras=fsavg_vox2ras,
         midslices=midslices,
         ac_coords=ac_coords,
         pc_coords=pc_coords,
@@ -764,7 +806,7 @@
         subdivisions=subdivisions,
         subdivision_method=subdivision_method,
         contour_smoothing=contour_smoothing,
-        vox_size=orig.header.get_zooms(),
+        vox_size=vox_size,
         vox2ras_tkr=fsavg_vox2ras_tkr,
         subject_dir=sd,
     )
@@ -779,7 +821,6 @@
 
     # Get middle slice result
     middle_slice_result = slice_results[len(slice_results) // 2]
-
     if len(middle_slice_result['split_contours']) <= 5:
         cc_subseg_midslice = make_subdivision_mask(
             cc_fn_seg_labels.shape[1:],
@@ -787,25 +828,27 @@
             orig.header.get_zooms(),
         )
     else:
-        logger.warning("Too many subsegments for lookup table, skipping sub-divion of output segmentation.")
+        logger.warning("Too many subsegments for lookup table, skipping sub-division of output segmentation.")
        cc_subseg_midslice = None
 
-    # map soft labels to original space (in parallel because this takes a while, and we only do it to save the labels)
-    io_futures.append(thread_executor().submit(
-        map_softlabels_to_orig,
-        cc_fn_softlabels=cc_fn_softlabels,
-        orig_fsaverage_vox2vox=orig2fsavg_vox2vox,
-        orig=orig,
-        orig_space_segmentation_path=segmentation_in_orig,
-        fsaverage_middle=FSAVERAGE_MIDDLE,
-        cc_subseg_midslice=cc_subseg_midslice,
-    ))
+    # save the segmentation labels for the mid-slab
     if sd.has_attribute("cc_segmentation"):
         io_futures.append(thread_executor().submit(
             nib.save,
-            nib.MGHImage(cc_fn_seg_labels, seg_affine, orig.header),
+            nib.MGHImage(cc_fn_seg_labels, midslab_vox2ras, orig.header),
             sd.filename_by_attribute("cc_segmentation"),
         ))
 
+    # map soft labels to original space (in parallel because this takes a while, and we only do it to save the labels)
+    if sd.has_attribute("cc_orig_segfile"):
+        io_futures.append(thread_executor().submit(
+            map_softlabels_to_orig,
+            cc_fn_softlabels=cc_fn_softlabels,
+            orig_fsaverage_vox2vox=orig2fsavg_vox2vox,
+            orig=orig,
+            orig_space_segmentation_path=sd.filename_by_attribute("cc_orig_segfile"),
+            fsaverage_middle=FSAVERAGE_MIDDLE,
+            cc_subseg_midslice=cc_subseg_midslice,
+        ))
 
     METRICS = [
         "areas",
@@ -824,10 +867,8 @@
 
     # Create enhanced output dictionary with all slice results
     per_slice_output_dict = {
-        "slices": [
-            convert_numpy_to_json_serializable({metric: result[metric] for metric in METRICS})
-            for result in slice_results
-        ],
+        "slices": [convert_numpy_to_json_serializable({metric: result[metric] for metric in METRICS})
+                   for result in slice_results],
     }
 
     ########## Save outputs ##########
 
     if len(outer_contours) > 1:
         cc_volume_voxel = segmentation_postprocessing.get_cc_volume_voxel(
             desired_width_mm=5,
-            cc_mask=cc_fn_seg_labels == CC_LABEL,
-            voxel_size=orig.header.get_zooms()
-        )
-        cc_volume_contour = segmentation_postprocessing.get_cc_volume_contour(
-            cc_contours=outer_contours,
-            voxel_size=orig.header.get_zooms()
+            cc_mask=np.equal(cc_fn_seg_labels, CC_LABEL),
+            voxel_size=vox_size,  # in LIA order
         )
         logger.info(f"CC volume voxel: {cc_volume_voxel}")
-        logger.info(f"CC volume contour: {cc_volume_contour}")
+        # FIXME: Create a proper mesh and use cc_mesh.volume for this volume
+        try:
+            cc_volume_contour = segmentation_postprocessing.get_cc_volume_contour(
cc_contours=outer_contours, + voxel_size=vox_size, # in LIA order + ) + logger.info(f"CC volume contour: {cc_volume_contour}") + except AssertionError as e: + logger.warning("Could not compute CC volume from contours, setting to NaN") + logger.exception(e) + cc_volume_contour = float('nan') additional_metrics["cc_5mm_volume"] = cc_volume_voxel additional_metrics["cc_5mm_volume_pv_corrected"] = cc_volume_contour - - # get ac and pc in all spaces ac_coords_3d = np.hstack((FSAVERAGE_MIDDLE, ac_coords)) pc_coords_3d = np.hstack((FSAVERAGE_MIDDLE, pc_coords)) @@ -865,7 +910,7 @@ def main( additional_metrics["ac_center_upright"] = ac_coords_3d additional_metrics["pc_center_upright"] = pc_coords_3d additional_metrics["slices_in_segmentation"] = slices_to_analyze - additional_metrics["voxel_size"] = [float(x) for x in orig.header.get_zooms()] + additional_metrics["voxel_size"] = np.asarray(orig.header.get_zooms(), dtype=float).tolist() additional_metrics["num_thickness_points"] = num_thickness_points additional_metrics["subdivision_method"] = subdivision_method additional_metrics["subdivision_ratios"] = subdivisions @@ -894,11 +939,12 @@ def main( if sd.has_attribute("upright_lta"): sd.filename_by_attribute("cc_mid_measures").parent.mkdir(exist_ok=True, parents=True) logger.info(f"Saving LTA to fsaverage space: {sd.filename_by_attribute('upright_lta')}") - io_futures.append(thread_executor().submit(write_lta, - sd.filename_by_attribute("upright_lta"), + io_futures.append(thread_executor().submit( + write_lta, + sd.filename_by_attribute("upright_lta"), orig2fsavg_ras2ras, sd.filename_by_attribute("aseg_name"), - aseg_nib.header, + aseg_img.header, "fsaverage", fsavg_header, )) @@ -908,7 +954,8 @@ def main( # save lta to standardized space (fsaverage + nodding + ac to center) orig2standardized_ras2ras = orig.affine @ np.linalg.inv(standardized2orig_vox2vox) @ np.linalg.inv(orig.affine) logger.info(f"Saving LTA to standardized space: {sd.filename_by_attribute('cc_orient_volume_lta')}") - io_futures.append(thread_executor().submit(write_lta, + io_futures.append(thread_executor().submit( + write_lta, sd.filename_by_attribute("cc_orient_volume_lta"), orig2standardized_ras2ras, sd.conf_name, diff --git a/CorpusCallosum/localization/inference.py b/CorpusCallosum/localization/inference.py index 78f7e595..fd6df1b1 100644 --- a/CorpusCallosum/localization/inference.py +++ b/CorpusCallosum/localization/inference.py @@ -139,7 +139,7 @@ def preprocess_volume( def run_inference( model: torch.nn.Module, image_volume: np.ndarray, - third_ventricle_center: np.ndarray, + patch_center: np.ndarray, device: torch.device | None = None, transform: transforms.Transform | None = None ) -> tuple[npt.NDArray[float], npt.NDArray[float], np.ndarray, tuple[int, int]]: @@ -152,7 +152,7 @@ def run_inference( Trained model for inference. image_volume : np.ndarray Input volume as numpy array. - third_ventricle_center : np.ndarray + patch_center : np.ndarray Initial center point estimate for cropping. device : torch.device, optional Device to run inference on, by default None. 
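The `orig2standardized_ras2ras` line in the hunk above is the mirror image of the earlier vox2vox composition: wrapping a vox2vox matrix with the image affine and its inverse turns it into a RAS-to-RAS transform. In isolation (identity placeholders, illustrative only):

```python
import numpy as np

vox2ras = np.eye(4)                    # the image's voxel -> RAS affine
standardized2orig_vox2vox = np.eye(4)  # voxel -> voxel mapping to convert

# read right-to-left: RAS -> vox, vox -> vox (inverted direction), vox -> RAS
orig2standardized_ras2ras = (
    vox2ras @ np.linalg.inv(standardized2orig_vox2vox) @ np.linalg.inv(vox2ras)
)
```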
@@ -174,10 +174,10 @@ def run_inference(
         device = next(model.parameters()).device
 
     # prepend zero to third_ventricle_center
-    third_ventricle_center = np.concatenate([np.zeros(1), third_ventricle_center])
+    patch_center_3d = np.concatenate([np.zeros(1), patch_center])
 
     # Preprocess
-    t_dict = preprocess_volume(image_volume, third_ventricle_center, transform)
+    t_dict = preprocess_volume(image_volume, patch_center_3d, transform)
     transformed_original = t_dict['image']
 
     inputs = transformed_original.to(device)
diff --git a/CorpusCallosum/segmentation/inference.py b/CorpusCallosum/segmentation/inference.py
index 70c63d86..a1a6cc5b 100644
--- a/CorpusCallosum/segmentation/inference.py
+++ b/CorpusCallosum/segmentation/inference.py
@@ -86,7 +86,7 @@ def run_inference(
     image_slice: np.ndarray,
     ac_center: np.ndarray,
     pc_center: np.ndarray,
-    voxel_size: float,
+    voxel_size: tuple[float, float],
     device: torch.device | None = None,
     transform: transforms.Transform | None = None
 ) -> tuple[npt.NDArray[int], npt.NDArray[float], npt.NDArray[float]]:
@@ -102,8 +102,8 @@
         Anterior commissure coordinates.
     pc_center : np.ndarray
         Posterior commissure coordinates.
-    voxel_size : float
-        Voxel size in mm.
+    voxel_size : a pair of floats
+        Voxel sizes in inferior/superior and anterior/posterior direction in mm.
     device : torch.device or None, optional
         Device to run inference on, by default None.
         If None, uses the device of the model.
@@ -126,8 +126,8 @@
     to_discrete = transforms.AsDiscrete(argmax=True, to_onehot=3)
 
     # Preprocess slice
-    _inputs = torch.from_numpy(image_slice[:,None,:256,:256]) # artifact from training script
-    sample = {'image': _inputs, 'AC_center': ac_center, 'PC_center': pc_center, 'res': voxel_size}
+    _inputs = torch.from_numpy(image_slice[:,None]) #,:256,:256]) # artifact from training script
+    sample = {'image': _inputs, 'AC_center': ac_center, 'PC_center': pc_center, 'res': np.asarray(voxel_size)}
     sample_cropped = crop_around_acpc(sample)
     _inputs, to_pad = sample_cropped['image'], sample_cropped['to_pad']
     _inputs = transforms.utils.rescale_array(_inputs, 0, 1, dtype=np.float32).to(device)
@@ -139,7 +139,7 @@
 
     # Post-process outputs
     with torch.no_grad():
-        scale_factors = torch.ones((_inputs.shape[0], 2), device=device) / voxel_size
+        scale_factors = torch.ones((_inputs.shape[0], 2), device=device) / torch.asarray([voxel_size], device=device)
         _logits = model(_inputs, scale_factor=scale_factors)
         _softlabels = transforms.Activations(softmax=True, dim=1)(_logits)
@@ -265,7 +265,7 @@ def run_inference_on_slice(
     test_slice: np.ndarray,
     ac_center: npt.NDArray[float],
     pc_center: npt.NDArray[float],
-    voxel_size: float,
+    voxel_size: tuple[float, float],
 ) -> tuple[npt.NDArray[int], np.ndarray, npt.NDArray[float]]:
     """Run inference on a single slice.
 
@@ -279,8 +279,8 @@
         Anterior commissure coordinates (Inferior and Anterior values).
     pc_center : npt.NDArray[float]
         Posterior commissure coordinates (Inferior and Posterior values).
-    voxel_size : float
-        Voxel size in mm.
+    voxel_size : a pair of floats
+        Voxel sizes in superior/inferior and anterior/posterior direction in mm.
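The switch from a scalar `voxel_size` to a pair makes the model's scale conditioning per-axis, which is what prepares the pipeline for non-isotropic in-plane resolutions. The shapes involved, with illustrative values (not the trained model's actual inputs):

```python
import torch

batch = 8
voxel_size = (0.8, 1.0)  # (inferior/superior, anterior/posterior) in mm

# one 1/voxel_size factor per axis and batch element, shape (batch, 2)
scale_factors = torch.ones((batch, 2)) / torch.tensor(voxel_size)
print(scale_factors[0])  # tensor([1.2500, 1.0000])
```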
Returns ------- diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py index 265ae634..c7cc970b 100644 --- a/CorpusCallosum/segmentation/segmentation_postprocessing.py +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -280,7 +280,7 @@ def connect_nearby_components(seg_arr: np.ndarray, max_connection_distance: floa def get_cc_volume_voxel( desired_width_mm: int, cc_mask: np.ndarray, - voxel_size: tuple[float, float, float] + voxel_size: tuple[float, float, float], ) -> float: """Calculate the volume of the corpus callosum in cubic millimeters. @@ -293,9 +293,9 @@ def get_cc_volume_voxel( desired_width_mm : int Desired width of the CC in millimeters. cc_mask : np.ndarray - Binary mask of the corpus callosum. - voxel_size : tuple[float, float, float] - Voxel size in millimeters (x, y, z). + Binary mask of the corpus callosum in LIA orientation. + voxel_size : triplet of floats + LIA-oriented Voxel size in millimeters (x, y, z). Returns ------- @@ -320,19 +320,20 @@ def get_cc_volume_voxel( # Calculate voxel volume - voxel_volume = np.prod(voxel_size) + voxel_volume: float = np.prod(voxel_size, dtype=float) + voxel_width: float = voxel_size[0] # Get width of CC mask in voxels by finding the extent in x dimension width_vox = np.sum(np.any(cc_mask, axis=(1,2))) # we are in LIA, so 0 is L/R resolution - width_mm = width_vox * voxel_size[0] + width_mm = width_vox * voxel_width if width_mm == desired_width_mm: return np.sum(cc_mask) * voxel_volume elif width_mm > desired_width_mm: # remainder on the left/right side of the CC mask - desired_width_vox = desired_width_mm / voxel_size[0] + desired_width_vox = desired_width_mm / voxel_width fraction_of_voxel_at_edge = (desired_width_vox % 1) / 2 if fraction_of_voxel_at_edge > 0: @@ -351,15 +352,18 @@ def get_cc_volume_voxel( else: raise ValueError(f"Width of CC segmentation is smaller than desired width: {width_mm} < {desired_width_mm}") -def get_cc_volume_contour(cc_contours: list[np.ndarray], - voxel_size: tuple[float, float, float]) -> float: + +def get_cc_volume_contour( + cc_contours: list[np.ndarray], + voxel_size: tuple[float, float, float], +) -> float: """Calculate the volume of the corpus callosum using Simpson's rule. Parameters ---------- cc_contours : list[np.ndarray] List of CC contours for each slice in the left-right direction. - voxel_size : tuple[float, float, float] + voxel_size : triplet of floats Voxel size in millimeters (x, y, z). Returns @@ -378,20 +382,28 @@ def get_cc_volume_contour(cc_contours: list[np.ndarray], using Simpson's rule. If the CC width is larger than desired_width_mm, the voxels on the edges are calculated as partial volumes to achieve the desired width. """ + # FIXME: This function is a shape-tool, it should therefore not be in segmentation.postprocessing... + # FIXME: this code currently produces volume estimates more that 50% off of the volume_based estimate in + # get_cc_volume_voxel... + if len(cc_contours) < 3: raise ValueError("Need at least 3 contours for Simpson's rule integration") - + + # FIXME: why can we not multiply by those numbers in line below other FIXME comment + # converting this to a warning for now... 
+    if voxel_size[1] != voxel_size[2]:
+        logger.warning("Non-isotropic voxel sizes in get_cc_volume_contour, but currently the volume must be isotropic!")
     # Calculate cross-sectional areas for each contour
     areas = []
-    for contour in cc_contours:
-        contour = contour.copy()
-        assert voxel_size[1] == voxel_size[2], "volume must be isotropic"
-        contour *= voxel_size[1]
 
+    for contour in cc_contours:
         # Calculate area using the shoelace formula for polygon area
         if contour.shape[1] < 3:
             areas.append(0.0)
         else:
+            # FIXME: we are multiplying by voxel size here and below "Convert from voxel^2 to mm^2", e.g.
+            #  x = contour[0] * voxel_size[1]
+            #  y = contour[1] * voxel_size[2]
+            contour = contour * voxel_size[1]
             x = contour[0]
             y = contour[1]
             # Shoelace formula: A = 0.5 * |sum(x_i * y_{i+1} - x_{i+1} * y_i)|
@@ -399,31 +411,31 @@
             # Convert from voxel^2 to mm^2
             area_mm2 = area * voxel_size[1] * voxel_size[2]  # y * z voxel dimensions
             areas.append(area_mm2)
-    
+
     areas = np.array(areas)
-    
+
     # Calculate spacing between slices (left-right direction)
     lr_spacing = voxel_size[0]  # x-direction voxel size
-    measurement_points = np.arange(-voxel_size[0]*(areas.shape[0]//2), 
-                                   voxel_size[0]*((areas.shape[0]+1)//2), lr_spacing)
-    
+    measurement_points = np.arange(-voxel_size[0]*(areas.shape[0]//2),
+                                   voxel_size[0]*((areas.shape[0]+1)//2), lr_spacing)
+
+    # FIXME: why interpolate at 0.25? Also, why do we need interpolation at all?
     # interpolate areas at 0.25 and 5
-    areas_interpolated = np.interp(x=[-2.5, 2.5], 
-                                   xp=measurement_points, 
+    areas_interpolated = np.interp(x=[-2.5, 2.5],
+                                   xp=measurement_points,
                                    fp=areas)
 
-    # remove measurement points that are outside of the desired range
     # not sure if this can happen, but let's be safe
     outside_range = (measurement_points < -2.5) | (measurement_points > 2.5)
     measurement_points = [-2.5] + measurement_points[~outside_range].tolist() + [2.5]
     areas = [areas_interpolated[0]] + areas[~outside_range].tolist() + [areas_interpolated[1]]
-    
-    
+
+    # can also use trapezoidal rule
     return integrate.simpson(areas, x=measurement_points)
-    
+
 
 def extract_largest_connected_component(
     seg_arr: np.ndarray,
diff --git a/CorpusCallosum/shape/endpoint_heuristic.py b/CorpusCallosum/shape/endpoint_heuristic.py
index d88f0e6a..eb5f6095 100644
--- a/CorpusCallosum/shape/endpoint_heuristic.py
+++ b/CorpusCallosum/shape/endpoint_heuristic.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Literal, overload
 
 import lapy
 import numpy as np
@@ -18,6 +19,8 @@
 import skimage.measure
 from scipy.ndimage import label
 
+from FastSurferCNN.utils import Vector2d
+
 
 def smooth_contour(x: np.ndarray, y: np.ndarray, window_size: int) -> tuple[np.ndarray, np.ndarray]:
     """Smooth a contour using a moving average filter.
@@ -136,8 +139,8 @@ def extract_cc_contour(cc_mask: np.ndarray, contour_smoothing: int = 5) -> np.nd
     ----------
     cc_mask : np.ndarray
         Binary mask of the corpus callosum.
-    contour_smoothing : int, optional
-        Window size for contour smoothing, by default 5.
+    contour_smoothing : int, default=5
+        Window size for contour smoothing.
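For reference, the two numerical tools used by `get_cc_volume_contour` are the shoelace formula for polygon area and Simpson's rule for integrating cross-sectional areas along the left-right axis. A self-contained sketch with toy data (not pipeline values):

```python
import numpy as np
from scipy import integrate

def shoelace_area(x: np.ndarray, y: np.ndarray) -> float:
    """Area of a simple polygon from its vertex coordinates."""
    return 0.5 * abs(np.sum(x * np.roll(y, -1) - np.roll(x, -1) * y))

# unit square has area 1
x = np.array([0.0, 1.0, 1.0, 0.0])
y = np.array([0.0, 0.0, 1.0, 1.0])
assert np.isclose(shoelace_area(x, y), 1.0)

# integrate per-slice areas over slice positions (Simpson's rule)
positions = np.array([-2.5, -1.25, 0.0, 1.25, 2.5])  # mm along left-right
areas = np.array([0.0, 300.0, 400.0, 300.0, 0.0])    # mm^2 per slice
volume = integrate.simpson(areas, x=positions)        # mm^3
```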
 
     Returns
     -------
@@ -163,38 +166,62 @@ def extract_cc_contour(cc_mask: np.ndarray, contour_smoothing: int = 5) -> np.nd
 
     return contour
 
+
+@overload
 def get_endpoints(
-    cc_mask: np.ndarray,
-    AC_2d: np.ndarray,
-    PC_2d: np.ndarray,
-    resolution: float,
-    return_coordinates: bool = True,
+    cc_mask: np.ndarray[tuple[int, int], np.dtype[bool]],
+    ac_2d: Vector2d,
+    pc_2d: Vector2d,
+    resolution: tuple[float, float],
+    return_coordinates: Literal[True],
     contour_smoothing: int = 5
-) -> tuple[np.ndarray, np.ndarray, np.ndarray] | tuple[np.ndarray, int, int]:
+) -> tuple[np.ndarray, tuple[int, int], tuple[Vector2d, Vector2d]]: ...
+
+
+@overload
+def get_endpoints(
+    cc_mask: np.ndarray[tuple[int, int], np.dtype[bool]],
+    ac_2d: Vector2d,
+    pc_2d: Vector2d,
+    resolution: tuple[float, float],
+    return_coordinates: Literal[False] = False,
+    contour_smoothing: int = 5
+) -> tuple[np.ndarray, tuple[int, int]]: ...
+
+
+def get_endpoints(
+    cc_mask: np.ndarray[tuple[int, int], np.dtype[bool]],
+    ac_2d: Vector2d,
+    pc_2d: Vector2d,
+    resolution: tuple[float, float],
+    return_coordinates: bool = False,
+    contour_smoothing: int = 5
+):
     """Determine endpoints of CC by finding points closest to AC and PC.
 
     Parameters
     ----------
-    cc_mask : np.ndarray
+    cc_mask : np.ndarray of shape (H, W) and type bool
         Binary mask of the corpus callosum.
-    AC_2d : np.ndarray
+    ac_2d : np.ndarray of shape (2,) and type float
         2D coordinates of the anterior commissure.
-    PC_2d : np.ndarray
+    pc_2d : np.ndarray of shape (2,) and type float
         2D coordinates of the posterior commissure.
-    resolution : float
-        Image resolution in mm.
-    return_coordinates : bool, optional
-        If True, return endpoint coordinates, otherwise return indices, by default True.
-    contour_smoothing : int, optional
-        Window size for contour smoothing, by default 5.
+    resolution : pair of floats
+        In-slice image resolution in mm (inferior/superior and anterior/posterior directions).
+    return_coordinates : bool, default=False
+        If True, return endpoint coordinates.
+    contour_smoothing : int, default=5
+        Window size for contour smoothing.
 
     Returns
     -------
-    tuple[np.ndarray, np.ndarray, np.ndarray] | tuple[np.ndarray, int, int]
-        If return_coordinates is True:
-            (contour, anterior_point, posterior_point).
-        If return_coordinates is False:
-            (contour, anterior_index, posterior_index).
+    contour_rotated : np.ndarray
+        The contour rotated to AC-PC alignment.
+    anterior_posterior_point_indices : pair of ints
+        Indices of anterior and posterior points in the contour.
+    anterior_posterior_point_coordinates : tuple[np.ndarray, np.ndarray]
+        Only if return_coordinates is True: Coordinates of anterior and posterior points rotated to AC-PC alignment.
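The new `@overload` pair encodes that the return arity of `get_endpoints` depends on the `return_coordinates` flag, so type checkers can pick the correct tuple shape at each call site. The pattern in isolation (a generic toy example, not the FastSurfer signatures):

```python
from typing import Literal, overload

@overload
def lookup(key: str, with_index: Literal[True]) -> tuple[str, int]: ...
@overload
def lookup(key: str, with_index: Literal[False] = False) -> str: ...

def lookup(key: str, with_index: bool = False):
    value, index = key.upper(), 0  # toy "lookup"
    return (value, index) if with_index else value

s: str = lookup("ac")                       # checker infers str
pair: tuple[str, int] = lookup("pc", True)  # checker infers tuple[str, int]
```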
Notes ----- @@ -203,7 +230,7 @@ def get_endpoints( image_size = cc_mask.shape # Calculate angle between AC-PC line and horizontal using numpy - ac_pc_vector = PC_2d - AC_2d + ac_pc_vector = pc_2d - ac_2d horizontal_vector = np.array([0, -20]) # Calculate angle using dot product formula: cos(theta) = (a·b)/(|a||b|) dot_product = np.dot(ac_pc_vector, horizontal_vector) @@ -223,11 +250,11 @@ def get_endpoints( rot_matrix = np.array([[np.cos(-theta), -np.sin(-theta)], [np.sin(-theta), np.cos(-theta)]]) # Translate points to origin, rotate, then translate back - pc_centered = PC_2d - origin_point - ac_centered = AC_2d - origin_point + pc_centered = pc_2d - origin_point + ac_centered = ac_2d - origin_point - rotated_PC_2d = (rot_matrix @ pc_centered) + origin_point - rotated_AC_2d = (rot_matrix @ ac_centered) + origin_point + rotated_pc_2d = (rot_matrix @ pc_centered) + origin_point + rotated_ac_2d = (rot_matrix @ ac_centered) + origin_point # Add z=0 coordinate to make 3D, then remove it after resampling contour_3d = np.vstack([contour, np.zeros(contour.shape[1])]) @@ -236,28 +263,26 @@ def get_endpoints( contour = contour[:, :-1] - rotated_AC_2d = np.array(rotated_AC_2d).astype(float) - rotated_PC_2d = np.array(rotated_PC_2d).astype(float) + rotated_ac_2d = np.array(rotated_ac_2d).astype(float) + rotated_pc_2d = np.array(rotated_pc_2d).astype(float) # move posterior commisure 5 mm posterior - rotated_PC_2d = rotated_PC_2d + np.array([10 * resolution, -5 * resolution]) + # FIXME: why is the move 10mm inferior not commented? + # FIXME: multiplication means moving less for smaller voxels, why not division? + # changed to division, 5 mm / voxel size => number of voxels to move + rotated_pc_2d = rotated_pc_2d + np.array([10, -5]) / resolution # move anterior commisure 1.5 mm anterior - rotated_AC_2d = rotated_AC_2d + np.array([0, 5 * resolution]) + # FIXME: why does the documentation say 1.5mm when the code says 5mm? 
+ rotated_ac_2d = rotated_ac_2d + np.array([0, 5]) / resolution # find point in contour closest to AC - AC_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_AC_2d[:, None], axis=0)) + ac_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_ac_2d[:, None], axis=0)) # find point in contour closest to PC - PC_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_PC_2d[:, None], axis=0)) + pc_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_pc_2d[:, None], axis=0)) # rotate startpoints to original orientation - # Create rotation matrix - rot_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) - - # rotate contour to original orientation - contour_rotated = np.zeros_like(contour) - origin_point = np.array(origin_point).astype(float) # Create rotation matrix rot_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) @@ -267,16 +292,8 @@ def get_endpoints( contour_rotated = (rot_matrix @ contour_centered) + origin_point[:, None] if return_coordinates: - AC_contour_point = contour[:, AC_startpoint_idx] - PC_contour_point = contour[:, PC_startpoint_idx] - - # Translate points to origin, rotate, then translate back - ac_centered = AC_contour_point - origin_point - pc_centered = PC_contour_point - origin_point - - start_point_A = (rot_matrix @ ac_centered) + origin_point - start_point_P = (rot_matrix @ pc_centered) + origin_point + start_point_ac, start_point_pc = contour_rotated[:, [ac_startpoint_idx, pc_startpoint_idx]].T - return contour_rotated, start_point_A, start_point_P + return contour_rotated, (ac_startpoint_idx, pc_startpoint_idx), (start_point_ac, start_point_pc) else: - return contour_rotated, AC_startpoint_idx, PC_startpoint_idx + return contour_rotated, (ac_startpoint_idx, pc_startpoint_idx) diff --git a/CorpusCallosum/shape/mesh.py b/CorpusCallosum/shape/mesh.py index 4e5f2e9c..4e889d59 100644 --- a/CorpusCallosum/shape/mesh.py +++ b/CorpusCallosum/shape/mesh.py @@ -26,6 +26,7 @@ from scipy.ndimage import gaussian_filter1d import FastSurferCNN.utils.logging as logging +from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE from CorpusCallosum.shape.endpoint_heuristic import smooth_contour from CorpusCallosum.shape.thickness import make_mesh_from_contour from FastSurferCNN.utils.common import suppress_stdout @@ -74,7 +75,7 @@ class CCMesh(lapy.TriaMesh): List of vertex indices where thickness was originally measured. """ - def __init__(self, num_slices): + def __init__(self, num_slices: int): """Initialize a CC_Mesh object. 
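The endpoint heuristic repeatedly applies the same geometric primitive: translate points to an origin, rotate, translate back. The operation in isolation, matching the rotation-matrix convention in the diff above (plain NumPy, illustrative values):

```python
import numpy as np

def rotate_about(points: np.ndarray, theta: float, origin: np.ndarray) -> np.ndarray:
    """Rotate 2D points of shape (2, N) by angle theta (radians) around `origin` (2,)."""
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta),  np.cos(theta)]])
    return rot @ (points - origin[:, None]) + origin[:, None]

# a quarter turn of (1, 0) about the origin lands on (0, 1)
pt = np.array([[1.0], [0.0]])
assert np.allclose(rotate_about(pt, np.pi / 2, np.zeros(2)), [[0.0], [1.0]])
```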
Parameters @@ -83,15 +84,21 @@ def __init__(self, num_slices): Number of slices in the corpus callosum mesh """ super().__init__(np.zeros((3, 3)), np.zeros((3, 3), dtype=int)) - self.contours = [None] * num_slices - self.thickness_values = [None] * num_slices - self.start_end_idx = [None] * num_slices - self.ac_coords = None - self.pc_coords = None - self.resolution = None + self.contours: list[np.ndarray | None] = [None] * num_slices + self.thickness_values: list[np.ndarray | None] = [None] * num_slices + self.start_end_idx: list[int | None] = [None] * num_slices + self.ac_coords: np.ndarray | None = None + self.pc_coords: np.ndarray | None = None + self.resolution: tuple[float, float, float] | None = None + # FIXME: v and t do not get properly initialized and all the data in the base class are basically unvalidated + # this class needs to be reworked to either: + # A) properly inherit from TriaMesh, calling super().__init__ with the correct values, or + # B) converting it into a Factory class that then outputs a correct TriaMesh object. + # Currently, there are no real behavior "guarantees" of objects, as the internal state of the object is + # very chaotic and uncontrolled with almost no safeguards (and/or debugging). self.v = None self.t = None - self.original_thickness_vertices = [None] * num_slices + self.original_thickness_vertices: list[np.ndarray | None] = [None] * num_slices def add_contour( self, @@ -137,13 +144,13 @@ def set_acpc_coords(self, ac_coords: np.ndarray, pc_coords: np.ndarray): self.ac_coords = ac_coords self.pc_coords = pc_coords - def set_resolution(self, resolution: float): + def set_resolution(self, resolution: tuple[float, float, float]): """Set the spatial resolution of the mesh. Parameters ---------- - resolution : float - Spatial resolution in millimeters. + resolution : triplet of floats + LIA-oriented spatial resolution of the mesh. """ self.resolution = resolution @@ -342,7 +349,7 @@ def plot_mesh( # Calculate z coordinates for each slice - use same calculation as in create_mesh lr_center = self.v[len(self.v) // 2][2] - z_coordinates = np.arange(num_slices) * self.resolution - (num_slices // 2) * self.resolution + lr_center + z_coordinates = (np.arange(num_slices) - (num_slices // 2)) * self.resolution[0] + lr_center for i in range(num_slices): if self.contours[i] is not None: @@ -747,9 +754,7 @@ def create_mesh(self, lr_center: float = 0, closed: bool = False, smooth: int = return # Calculate z coordinates for each slice - z_coordinates = ( - np.arange(len(valid_contours)) * self.resolution - (len(valid_contours) // 2) * self.resolution + lr_center - ) + z_coordinates = (np.arange(len(valid_contours)) - len(valid_contours) // 2) * self.resolution[0] + lr_center # Build vertices list with z-coordinates vertices = [] @@ -1190,6 +1195,7 @@ def set_mesh(self, self.fsinfo = None # Skip parent initialization since we have no faces else: + #FIXME: based on this call and CCMesh.__init__, this whole class probably needs a rework. super().__init__(np.vstack(vertices), np.vstack(faces)) if thickness_values is not None: @@ -1587,7 +1593,6 @@ def __make_parent_folder(filename: Path | str) -> None: def to_fs_coordinates( self, vox2ras_tkr: np.ndarray, - vox_size: tuple[float, float, float], ) -> None: """Convert mesh coordinates to FreeSurfer coordinate system. @@ -1595,16 +1600,18 @@ def to_fs_coordinates( ---------- vox2ras_tkr : np.ndarray 4x4 voxel to RAS tkr-space transformation matrix. - vox_size : 3-tuple of floats - Voxel size in millimeters (x, y, z). 
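The z-coordinate expression that now appears twice in `mesh.py` simply spaces the slices symmetrically around a center at the left-right resolution. Worked out with small numbers (values are illustrative):

```python
import numpy as np

num_slices, lr_resolution, lr_center = 5, 0.8, 128.0

# slice offsets -2, -1, 0, 1, 2 scaled to mm and shifted to the center
z = (np.arange(num_slices) - num_slices // 2) * lr_resolution + lr_center
print(z)  # [126.4 127.2 128.  128.8 129.6]
```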
Notes ----- + Mesh coordinates seem to be in ASR (Anterior-Superior-Right) orientation, with the coordinate system origin on + *the* midslice. + The function: - 1. Converts coordinates from original to LSA orientation. - 2. Converts to voxel coordinates using voxel size. - 3. Centers LR coordinates and flips SI coordinates. - 4. Applies vox2ras_tkr transformation to get final coordinates. + 1. convert from mesh coordinates (LSA and voxel coordinates) to fsaverage voxel coordinates (LIA, origin). + a. Converts coordinates from ASR to LSA orientation. + b. Converts to voxel coordinates using voxel size. + c. Centers LR coordinates and flips SI coordinates. + 2. Applies vox2ras_tkr transformation to get final coordinates. """ # to voxel coordinates @@ -1613,9 +1620,13 @@ def to_fs_coordinates( # to LSA v_vox = v_vox[:, [2, 1, 0]] # to voxel - v_vox /= vox_size[0] + # FIXME: why are the vertex positions multiplied by voxel size here? + # removed => for center LR, now dividing by resolution => convert fsaverage middle from mm to vox + # => remove the conversion back to mm in the end + # all other operations are independent of order of operations (distributive) + # v_vox /= vox_size[0] # center LR - v_vox[:, 0] += 256 // 2 + v_vox[:, 0] += FSAVERAGE_MIDDLE / self.resolution[0] # flip SI v_vox[:, 1] = -v_vox[:, 1] @@ -1629,7 +1640,8 @@ def to_fs_coordinates( # Torig: mri_info --vox2ras-tkr orig.mgz # https://surfer.nmr.mgh.harvard.edu/fswiki/CoordinateSystems self.v = (vox2ras_tkr @ np.concatenate([v_vox, np.ones((self.v.shape[0], 1))], axis=1).T).T[:, :3] - self.v = self.v * vox_size[0] + # FIXME: why are the vertex positions multiplied by voxel size here? + # self.v = self.v * vox_size[0] def write_fssurf(self, filename: Path | str) -> None: """Write the mesh to a FreeSurfer surface file. diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index f3196af2..3dcd5454 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -13,10 +13,9 @@ # limitations under the License. import concurrent.futures from functools import partial -from typing import Literal, get_args +from typing import Literal, TypedDict, get_args import numpy as np -from numpy import typing as npt import FastSurferCNN.utils.logging as logging from CorpusCallosum.data.constants import CC_LABEL, FSAVERAGE_MIDDLE, SUBSEGMENT_LABELS @@ -32,6 +31,7 @@ ) from CorpusCallosum.shape.thickness import cc_thickness, convert_to_ras from CorpusCallosum.utils.visualization import plot_contours +from FastSurferCNN.utils import AffineMatrix4x4, Vector2d from FastSurferCNN.utils.common import SubjectDirectory, suppress_stdout, update_docstring from FastSurferCNN.utils.parallel import process_executor, thread_executor @@ -47,16 +47,65 @@ LIA_ORIENTATION[2,1] = -1 -def create_slice_affine(upright_affine: np.ndarray, slice_idx: int, fsaverage_middle: int) -> np.ndarray: - """Create slice-specific affine transformation matrix. +class CCMeasuresDict(TypedDict): + """TypedDict for corpus callosum measures. + + Attributes + ---------- + cc_index : float + Corpus callosum shape index. + circularity : float + Shape circularity measure. + areas : np.ndarray + Areas of subdivided regions. + midline_length : float + Length along the midline. + thickness : float + Array of thickness measurements. + curvature : float + Array of curvature measurements. + thickness_profile : np.ndarray of type float + Thickness measurements along the contour. 
+    total_area : float
+        Total area of the CC.
+    total_perimeter : float
+        Total perimeter length.
+    split_contours : list of np.ndarray
+        Subdivided contour segments in AS-slice coordinates.
+    midline_equidistant : np.ndarray
+        Equidistant points along midline in AS-slice coordinates.
+    levelpaths : list of np.ndarray
+        Paths for thickness measurements in AS-slice coordinates.
+    slice_index : int
+        Index of the processed slice.
+    """
+    cc_index: float
+    circularity: float
+    areas: np.ndarray
+    midline_length: float
+    thickness: float
+    curvature: float
+    thickness_profile: np.ndarray[tuple[int], np.dtype[float]]
+    total_area: float
+    total_perimeter: float
+    split_contours: list[np.ndarray]
+    midline_equidistant: np.ndarray
+    levelpaths: list[np.ndarray]
+    slice_index: int
+
+
+def create_sag_slice_vox2vox(slice_idx: int, fsaverage_middle: float) -> AffineMatrix4x4:
+    """Create a slice-specific, slice-to-full-volume vox2vox affine transformation matrix.
+
+    Returns the affine that maps voxel coordinates of the extracted slice into voxel coordinates of the full volume.

    Parameters
    ----------
-    upright_affine : np.ndarray
-        Base 4x4 affine transformation matrix.
    slice_idx : int
        Index of the slice to transform.
-    fsaverage_middle : int
+    fsaverage_middle : float
        Reference middle slice index in fsaverage space.

    Returns
@@ -64,27 +113,27 @@ def create_slice_affine(upright_affine: np.ndarray, slice_idx: int, fsaverage_mi
    np.ndarray
        Modified 4x4 affine transformation matrix for the specific slice.
    """
-    slice_affine = upright_affine.copy()
-    slice_affine[0, 3] = -fsaverage_middle + slice_idx
-    return slice_affine
+    slice2full_vox2vox: AffineMatrix4x4 = np.eye(4, dtype=float)
+    slice2full_vox2vox[0, 3] = -fsaverage_middle + slice_idx
+    return slice2full_vox2vox


 @update_docstring(SubdivisionMethod=str(get_args(SubdivisionMethod))[1:-1])
 def recon_cc_surf_measures_multi(
     segmentation: np.ndarray,
     slice_selection: SliceSelection,
-    upright_affine: np.ndarray,
+    fsavg_vox2ras: np.ndarray,
     midslices: np.ndarray,
     ac_coords: np.ndarray,
     pc_coords: np.ndarray,
     num_thickness_points: int,
     subdivisions: list[float],
     subdivision_method: SubdivisionMethod,
-    contour_smoothing: float,
+    contour_smoothing: int,
     subject_dir: SubjectDirectory,
-    vox_size: tuple[float, float, float] | None = None,
+    vox_size: tuple[float, float, float],
     vox2ras_tkr: np.ndarray | None = None,
-) -> tuple[list, list[concurrent.futures.Future]]:
+) -> tuple[list[CCMeasuresDict], list[concurrent.futures.Future]]:
    """Surface reconstruction and metrics computation of corpus callosum slices based on selection mode.

    Parameters
    ----------
@@ -93,7 +142,7 @@ def recon_cc_surf_measures_multi(
        3D segmentation array.
    slice_selection : str
        Which slices to process ('middle', 'all', or slice number).
-    upright_affine : np.ndarray
+    fsavg_vox2ras : np.ndarray
        Base affine transformation matrix (fsaverage, upright space).
    midslices : np.ndarray
        Array of mid-sagittal slices.
@@ -107,23 +156,23 @@ def recon_cc_surf_measures_multi(
        List of fractions for anatomical subdivisions.
    subdivision_method : {SubdivisionMethod}
        Method for contour subdivision.
-    contour_smoothing : float
+    contour_smoothing : int
        Gaussian sigma for contour smoothing.
    subject_dir : SubjectDirectory
        The SubjectDirectory object managing file names in the subject directory.
-    vox_size : 3-tuple of floats, optional
-        Voxel size in millimeters (x, y, z).
+    vox_size : 3-tuple of floats
+        LIA-oriented voxel size in millimeters (x, y, z).
    vox2ras_tkr : np.ndarray, optional
        Voxel to RAS tkr-space transformation matrix.

    Returns
    -------
-    list
+    list of CCMeasuresDict
        List of slice processing results.
-    list[concurrent.futures.Future]
+    list of concurrent.futures.Future
        List of background IO processes.
    """
-    slice_results = []
+    slice_cc_measures: list[CCMeasuresDict] = []
    io_futures = []

    if subdivision_method == "angular" and not np.allclose(np.diff(subdivisions), np.diff(subdivisions)[0]):
@@ -132,7 +181,8 @@
            f"but got: {subdivisions}. No measures are computed.",
        )

-    _each_slice = partial(recon_cc_surf_measure,
+    _each_slice = partial(
+        recon_cc_surf_measure,
        segmentation,
        ac_coords=ac_coords,
        pc_coords=pc_coords,
        num_thickness_points=num_thickness_points,
        subdivisions=subdivisions,
        subdivision_method=subdivision_method,
        contour_smoothing=contour_smoothing,
-        vox_size=vox_size[0],
+        vox_size=vox_size,
    )

    # Process multiple slices or specific slice
    if slice_selection == "middle":
        num_slices = 1
        # Process only the middle slice
-        slice_iterator = [segmentation.shape[0] // 2]
+        slices_to_recon = [segmentation.shape[0] // 2]
        start_slice = segmentation.shape[0] // 2
    elif slice_selection == "all":
        num_slices = segmentation.shape[0]
        start_slice = 0
        end_slice = segmentation.shape[0]
-        slice_iterator = range(start_slice, end_slice)
+        slices_to_recon = range(start_slice, end_slice)
    else:
        # specific slice number
        num_slices = 1
-        slice_iterator = [int(slice_selection)]
+        slices_to_recon = [int(slice_selection)]
        start_slice = int(slice_selection)

-    it_affine = map(partial(create_slice_affine, upright_affine, fsaverage_middle=FSAVERAGE_MIDDLE), slice_iterator)
+    _gen_slice2fsavg_vox2vox = partial(create_sag_slice_vox2vox, fsaverage_middle=FSAVERAGE_MIDDLE)
+    per_slice_vox2ras = fsavg_vox2ras @ np.stack(list(map(_gen_slice2fsavg_vox2vox, slices_to_recon)), axis=0)

-    iterator = process_executor().map(_each_slice, iter(slice_iterator), it_affine, chunksize=1)
+    per_slice_recon = process_executor().map(_each_slice, slices_to_recon, per_slice_vox2ras, chunksize=1)

    cc_mesh = CCMesh(num_slices=num_slices)
    cc_mesh.set_acpc_coords(ac_coords, pc_coords)
-    cc_mesh.set_resolution(vox_size[0])
-
-    def _yield_iterator():
-        for _slice_idx in slice_iterator:
-            try:
-                yield _slice_idx, *next(iterator)
-            except Exception as e:
-                logger.error(f"Slice {_slice_idx} failed with error: {e}")
-                logger.exception(e)
-            except StopIteration:
-                logger.error(f"Unexpectedly skipping slice {_slice_idx} in CC surfaces.")
-                return
-
-    for i, (slice_idx, result, contour_with_thickness, *endpoint_idxs) in enumerate(_yield_iterator()):
-        # insert
+    cc_mesh.set_resolution(vox_size)
+
+    for i, (slice_idx, _results) in enumerate(zip(slices_to_recon, per_slice_recon, strict=True)):
        progress = f" ({i+1} of {num_slices})" if num_slices > 1 else ""
        logger.info(f"Calculating CC measurements for slice {slice_idx+1}{progress}")
-        cc_mesh.add_contour(start_slice-slice_idx, *contour_with_thickness, start_end_idx=endpoint_idxs)
-        if result is None:
-            continue
+        cc_measures, contour_in_as_space_and_thickness, endpoint_idxs = _results
+        contour_in_as_space, thickness_values = np.split(contour_in_as_space_and_thickness, (2,), axis=1)
+        cc_mesh.add_contour(start_slice-slice_idx, contour_in_as_space, thickness_values[:, 0], start_end_idx=endpoint_idxs)
+        if cc_measures is None:
+            # this should not happen, but just in case
+            logger.warning(f"Slice index {slice_idx+1}{progress} returned result `None`")

-        slice_results.append(result)
+        slice_cc_measures.append(cc_measures)

        if logger.getEffectiveLevel() <= logging.INFO and subject_dir.has_attribute("cc_qc_image"):
            qc_img = subject_dir.filename_by_attribute("cc_qc_image")
@@ -201,13 +243,13 @@ def _yield_iterator():
                thread_executor().submit(
                    plot_contours,
                    transformed=midslices[current_slice_in_volume:current_slice_in_volume+1],
-                    split_contours=result["split_contours"],
-                    midline_equidistant=result["midline_equidistant"],
-                    levelpaths=result["levelpaths"],
+                    split_contours=cc_measures["split_contours"],
+                    midline_equidistant=cc_measures["midline_equidistant"],
+                    levelpaths=cc_measures["levelpaths"],
                    output_path=qc_img,
                    ac_coords=ac_coords,
                    pc_coords=pc_coords,
-                    vox_size=vox_size[0],
+                    vox_size=vox_size,
                    title=f"CC Subsegmentation by {subdivision_method} (Slice {slice_idx})",
                )
            )
@@ -219,14 +261,14 @@ def _yield_iterator():
        template_dir.mkdir(parents=True, exist_ok=True)
        logger.info("Saving template files (contours.txt, thickness_values.txt, "
                    f"thickness_measurement_points.txt) to {template_dir}")
-        for fut in [
+        io_futures.extend([
            thread_executor().submit(cc_mesh.save_contours, template_dir / "contours.txt"),
            thread_executor().submit(cc_mesh.save_thickness_values, template_dir / "thickness_values.txt"),
-            thread_executor().submit(cc_mesh.save_thickness_measurement_points,
-                                     template_dir / "thickness_measurement_points.txt"),
-        ]:
-            if fut.exception():
-                logger.exception(fut.exception())
+            thread_executor().submit(
+                cc_mesh.save_thickness_measurement_points,
+                template_dir / "thickness_measurement_points.txt",
+            ),
+        ])

    mesh_outputs = ("html", "mesh", "thickness_overlay", "surf", "thickness_image")
    if len(cc_mesh.contours) > 1 and any(subject_dir.has_attribute(f"cc_{n}") for n in mesh_outputs):
@@ -246,7 +288,7 @@ def _yield_iterator():
            logger.info(f"Saving vtk file to {vtk_file_path}")
            io_futures.append(thread_executor().submit(cc_mesh.write_vtk, vtk_file_path))

-        cc_mesh.to_fs_coordinates(vox2ras_tkr=vox2ras_tkr, vox_size=vox_size)
+        cc_mesh.to_fs_coordinates(vox2ras_tkr=vox2ras_tkr)
        if subject_dir.has_attribute("cc_thickness_overlay"):
            overlay_file_path = subject_dir.filename_by_attribute("cc_thickness_overlay")
            logger.info(f"Saving overlay file to {overlay_file_path}")
@@ -265,25 +307,25 @@ def _yield_iterator():

            cc_mesh.snap_cc_picture(thickness_image_path)

-    if not slice_results:
+    if not slice_cc_measures:
        logger.error("Error: No valid slices were found for postprocessing")
        raise ValueError("No valid slices were found for postprocessing")

-    return slice_results, io_futures
+    return slice_cc_measures, io_futures


 def recon_cc_surf_measure(
-    segmentation: np.ndarray,
+    segmentation: np.ndarray[tuple[int, int, int], np.dtype[np.integer]],
     slice_idx: int,
-    affine: np.ndarray,
-    ac_coords: np.ndarray,
-    pc_coords: np.ndarray,
+    affine: AffineMatrix4x4,
+    ac_coords: Vector2d,
+    pc_coords: Vector2d,
     num_thickness_points: int,
     subdivisions: list[float],
     subdivision_method: SubdivisionMethod,
-    contour_smoothing: float,
-    vox_size: float,
-) -> tuple[dict[str, float | int | np.ndarray | list[float]], np.ndarray, int, int]:
+    contour_smoothing: int,
+    vox_size: tuple[float, float, float],
+) -> tuple[CCMeasuresDict, np.ndarray, tuple[int, int]]:
    """Reconstruct surfaces and compute measures for a single slice for the corpus callosum.

    Parameters
    ----------
@@ -292,11 +334,11 @@ def recon_cc_surf_measure(
        3D segmentation array.
    slice_idx : int
        Index of the slice to process.
-    affine : np.ndarray
+    affine : AffineMatrix4x4
        4x4 affine transformation matrix.
-    ac_coords : np.ndarray
+    ac_coords : np.ndarray of shape (2,) and type float
        Anterior commissure coordinates.
-    pc_coords : np.ndarray
+    pc_coords : np.ndarray of shape (2,) and type float
        Posterior commissure coordinates.
    num_thickness_points : int
        Number of points for thickness estimation.
@@ -304,36 +346,19 @@ def recon_cc_surf_measure(
        List of fractions for anatomical subdivisions.
    subdivision_method : SubdivisionMethod
        Method for contour subdivision ('shape', 'vertical', 'angular', or 'eigenvector').
-    contour_smoothing : float
+    contour_smoothing : int
        Gaussian sigma for contour smoothing.
-    vox_size : float
-        Voxel size in millimeters.
+    vox_size : triplet of floats
+        LIA-oriented voxel size in millimeters.

    Returns
    -------
-    measures : dict
-        Dictionary containing measurements if successful, including:
-
-        - cc_index : float
-            Corpus callosum shape index.
-        - circularity : float
-            Shape circularity measure.
-        - areas : np.ndarray
-            Areas of subdivided regions.
-        - midline_length : float
-            Length along the midline.
-        - thickness : np.ndarray
-            Array of thickness measurements.
-        - curvature : np.ndarray
-            Array of curvature measurements.
-        - thickness_profile : list[float]
-            Thickness measurements along the contour.
-        - total_area : float
-            Total area of the CC.
-        - total_perimeter : float
-            Total perimeter length.
-        - split_contours : list[np.ndarray]
-            Subdivided contour segments.
-        - midline_equidistant : np.ndarray
-            Equidistant points along midline.
-        - levelpaths : list[np.ndarray]
-            Paths for thickness measurements.
-        - thickness_measurement_points : np.ndarray
-            Points where thickness was measured.
-        - slice_index : int
-            Index of the processed slice.
+    measures : CCMeasuresDict
+        Dictionary containing measurements if successful.
    contour_with_thickness : np.ndarray
        Contour points with thickness information.
-    anterior_endpoint_index : int
-        Index of the anterior endpoint on the contour.
-    posterior_endpoint_index : int
-        Index of the posterior endpoint on the contour.
+    endpoint_indices : pair of ints
+        Indices of the anterior and posterior endpoints on the contour.

    Raises
    ------
@@ -348,74 +373,72 @@ def recon_cc_surf_measure(
    3. Calculates thickness profile using Laplace equation.
    4. Computes shape metrics and subdivisions.
    5. Generates visualization data.
    """
-    cc_mask_slice: npt.NDArray[bool] = segmentation[slice_idx] == CC_LABEL
+    cc_mask_slice: np.ndarray[tuple[int, int], np.dtype[bool]] = np.equal(segmentation[slice_idx], CC_LABEL)

    if not np.any(cc_mask_slice):
        raise ValueError(f"No CC found in slice {slice_idx}")
-
-    contour, *endpoint_idxs = get_endpoints(
+    contour, endpoint_idxs = get_endpoints(
        cc_mask_slice,
        ac_coords,
        pc_coords,
-        vox_size,
+        (vox_size[1], vox_size[2]),
        return_coordinates=False,
        contour_smoothing=contour_smoothing,
    )

-    contour_1mm = convert_to_ras(contour, affine)
+    contour_ras = convert_to_ras(contour, affine)

-    midline_len, thickness, curvature, midline_equi, levelpaths, contour_with_thickness, *endpoint_idxs = cc_thickness(
-        contour_1mm.T,
-        *endpoint_idxs,
+    endpoint_idxs: tuple[int, int]
+    contour_with_thickness: np.ndarray[tuple[int, Literal[3]], np.dtype[np.floating]]
+    midline_len, thickness, curvature, midline_equi, levelpaths, contour_with_thickness, endpoint_idxs = cc_thickness(
+        contour_ras[1:].T,
+        endpoint_idxs,
        n_points=num_thickness_points,
    )
-
-    thickness_profile = [
-        np.sum(np.sqrt(np.diff(np.array(levelpath[:,:2]), axis=0)**2), axis=0)
-        for levelpath in levelpaths
-    ]
-    thickness_profile = np.linalg.norm(np.array(thickness_profile),axis=1)
-
-    acpc_contour_coords = contour_1mm[:, list(endpoint_idxs)].T
-    contour_acpc, ac_pt_acpc, pc_pt_acpc, rotate_back_acpc = transform_to_acpc_standard(
-        contour_1mm,
-        *acpc_contour_coords,
+    # thickness values in contour_with_thickness are not equally sampled and have a different shape
+    # to compute length of paths: diff between consecutive points (N-1, 2) => norm (N-1,) => sum (scalar)
+    thickness_profile = np.stack([np.sum(np.linalg.norm(np.diff(x[:, :2], axis=0), axis=1)) for x in levelpaths])
+
+    acpc_contour_coords_ras = contour_ras[:, list(endpoint_idxs)].T
+    contour_in_acpc_space, ac_pt_acpc, pc_pt_acpc, rotate_back_acpc = transform_to_acpc_standard(
+        contour_ras[1:],
+        *acpc_contour_coords_ras[:, 1:],
    )

-    cc_index = calculate_cc_index(contour_acpc)
+    cc_index = calculate_cc_index(contour_in_acpc_space)

    # Apply different subdivision methods based on user choice
    if subdivision_method == "shape":
-        areas, split_contours = subsegment_midline_orthogonal(midline_equi, subdivisions, contour_1mm, plot=False)
-        split_contours = [transform_to_acpc_standard(split_contour, *acpc_contour_coords)[0]
+        _subdivisions = np.asarray(subdivisions)
+        areas, split_contours = subsegment_midline_orthogonal(midline_equi, _subdivisions, contour_ras[1:], plot=False)
+        split_contours = [transform_to_acpc_standard(split_contour, *acpc_contour_coords_ras[:, 1:])[0]
                          for split_contour in split_contours]
    elif subdivision_method == "vertical":
-        areas, split_contours = subdivide_contour(contour_acpc, subdivisions, plot=False)
+        areas, split_contours = subdivide_contour(contour_in_acpc_space, subdivisions, plot=False)
    elif subdivision_method == "angular":
        if not np.allclose(np.diff(subdivisions), np.diff(subdivisions)[0]):
            raise ValueError(
                f"Angular subdivision method (Hampel) only supports equidistant subdivision, "
                f"but got: {subdivisions}. No measures are computed.",
            )
-        areas, split_contours = hampel_subdivide_contour(contour_acpc, num_rays=len(subdivisions), plot=False)
+        areas, split_contours = hampel_subdivide_contour(contour_in_acpc_space, num_rays=len(subdivisions), plot=False)
    elif subdivision_method == "eigenvector":
-        pt0, pt1 = get_primary_eigenvector(contour_acpc)
-        contour_eigen, _, _, rotate_back_eigen = transform_to_acpc_standard(contour_acpc, pt0, pt1)
+        pt0, pt1 = get_primary_eigenvector(contour_in_acpc_space)
+        contour_eigen, _, _, rotate_back_eigen = transform_to_acpc_standard(contour_in_acpc_space, pt0, pt1)
        ac_pt_eigen, _, _, _ = transform_to_acpc_standard(ac_pt_acpc[:, None], pt0, pt1)
        ac_pt_eigen = ac_pt_eigen[:, 0]
        areas, split_contours = subdivide_contour(contour_eigen, subdivisions, oriented=True, hline_anchor=ac_pt_eigen)
        split_contours = [rotate_back_eigen(split_contour) for split_contour in split_contours]

    total_area = np.sum(areas)
-    total_perimeter = np.sum(np.sqrt(np.sum((np.diff(contour_1mm, axis=0))**2, axis=1)))
+    # perimeter: per-point AS deltas (2, N-1) => squared, summed over coords => segment lengths => sum
+    total_perimeter = np.sum(np.sqrt(np.sum(np.diff(contour_ras[1:], axis=1) ** 2, axis=0)))
    circularity = 4 * np.pi * total_area / (total_perimeter**2)

    # Transform split contours back to original space
    split_contours = [rotate_back_acpc(split_contour) for split_contour in split_contours]

-    measures = {
+    measures: CCMeasuresDict = {
        "cc_index": cc_index,
        "circularity": circularity,
-        "areas": areas,
+        "areas": np.asarray(areas),
        "midline_length": midline_len,
        "thickness": thickness,
        "curvature": curvature,
@@ -427,7 +450,7 @@ def recon_cc_surf_measure(
        "levelpaths": levelpaths,
        "slice_index": slice_idx
    }
-    return measures, contour_with_thickness, *endpoint_idxs
+    return measures, contour_with_thickness, endpoint_idxs


 def vectorized_line_test(coords_x: np.ndarray, coords_y: np.ndarray,
@@ -463,8 +486,6 @@ def vectorized_line_test(coords_x: np.ndarray, coords_y: np.ndarray,

    return cross_products > 0

-
-
 def get_unique_contour_points(split_contours: list[tuple[np.ndarray, np.ndarray]]) -> list[np.ndarray]:
    """Get unique contour points from the split contours.

@@ -535,6 +556,8 @@ def make_subdivision_mask(
    split_contours : list[tuple[np.ndarray, np.ndarray]]
        List of contours defining the subdivisions.
        Each contour is a tuple of x and y coordinates.
+    vox_size : triplet of floats
+    

    Returns
    -------
@@ -623,4 +646,4 @@ def check_area_changes(contours: list[np.ndarray], threshold: float = 0.3) -> bo
                for i, p in zip(indices, percent_change, strict=True))
        )
        return False
-    return True
\ No newline at end of file
+    return True
diff --git a/CorpusCallosum/shape/subsegment_contour.py b/CorpusCallosum/shape/subsegment_contour.py
index c1e37359..d1a7d112 100644
--- a/CorpusCallosum/shape/subsegment_contour.py
+++ b/CorpusCallosum/shape/subsegment_contour.py
@@ -13,7 +13,7 @@
 # limitations under the License.
from collections.abc import Callable -from typing import TypeVar +from typing import Literal, TypeVar import matplotlib.pyplot as plt import numpy as np @@ -104,7 +104,14 @@ def calc_subsegment_areas(split_contours: list[npt.NDArray[_TS]]) -> npt.NDArray return np.ediff1d(np.asarray(areas)[::-1], to_end=areas[-1]) -def subsegment_midline_orthogonal(midline, area_weights, contour, plot=True, ax=None, extremes=None): +def subsegment_midline_orthogonal( + midline: np.ndarray[tuple[int, Literal[2]], np.dtype[float]], + area_weights: np.ndarray[tuple[int], np.dtype[float]], + contour: np.ndarray[tuple[Literal[2], int], np.dtype[_TS]], + plot: bool = True, + ax=None, + extremes=None, +): """Subsegment contour orthogonally to the midline based on area weights. Parameters @@ -114,7 +121,7 @@ def subsegment_midline_orthogonal(midline, area_weights, contour, plot=True, ax= area_weights : np.ndarray Array of weights for area-based subdivision. contour : np.ndarray - Array of shape (2, M) containing contour points. + Array of shape (2, M) containing contour points in as space. plot : bool, optional Whether to plot the results, by default True. ax : matplotlib.axes.Axes, optional @@ -130,6 +137,7 @@ def subsegment_midline_orthogonal(midline, area_weights, contour, plot=True, ax= List of contour arrays for each subsegment. """ + # FIXME: why does this code return subsegments that include all previous segments? # get points after midline length of splits # get vertex closest to midline end @@ -149,6 +157,44 @@ def subsegment_midline_orthogonal(midline, area_weights, contour, plot=True, ax= split_contours = [contour] + # FIXME: double loop should be vectorized, see commented code below for an initial attempt (not tested) + # also, finding intersections can be done more efficiently, instead of solving linear system for each segment + # we could just look for changes in the sign of cross products + # mid_to_contour: np.ndarray = contour[:, :, None] - split_points[:, None] + # mid_to_contour_length = np.linalg.norm(mid_to_contour, axis=0) + # mid_to_contour_norm = mid_to_contour / mid_to_contour_length[None] + # sin_theta = mid_to_contour_norm[0] * edge_ortho_vectors[1] - mid_to_contour_norm[1] * edge_ortho_vectors[0] + # index_on_contour, index_on_segment = np.where(sin_theta[:-1] * sin_theta[1:] < 0) + # sin_theta_x = sin_theta[index_on_segment] + # cos_theta_x = np.sqrt(1 - sin_theta_x * sin_theta_x) + # rot_mat = np.array([[cos_theta_x, -sin_theta_x], [sin_theta_x, cos_theta_x]]) + # # rotate mid_to_contour by sin_theta + # _mid_to_intersection = rot_mat.transpose(0, -1) @ mid_to_contour[:, None, (index_on_contour, index_on_segment)] + # mid_to_intersection = cos_theta_x * _mid_to_intersection[:, 0, :] + # intersection_points = split_points[:, index_on_segment] + mid_to_intersection + # mid_to_intersection_length = np.linalg.norm(mid_to_intersection, axis=0) + # + # + # for segment_idx in range(split_points.shape[1]): + # mask = index_on_segment == segment_idx + # if any(mask): + # # first_index and second_index are the indices on the contour + # # _first_index and _second_index are the indices on the intersection_points of this segment + # _first_index, _second_index, *_ = np.argsort(mid_to_intersection_length[mask]) + # first_index, second_index = index_on_contour[mask][[_first_index, _second_index]] + # if first_index > second_index: + # first_index, second_index = second_index, first_index + # _first_index, _second_index = _second_index, _first_index + # # connect first and second half + # 
start_to_cutoff = np.hstack( + # ( + # contour[:, :first_index + 1], # includes first_index + # intersection_points[:, mask][:, [_first_index, _second_index]], + # contour[:, second_index + 1 :], # excludes second_index + # ) + # ) + # split_contours.append(start_to_cutoff) + for pt_idx, split_point in enumerate(split_points): intersections = [] for i in range(contour.shape[1] - 1): diff --git a/CorpusCallosum/shape/thickness.py b/CorpusCallosum/shape/thickness.py index a7f10d1a..15b99503 100644 --- a/CorpusCallosum/shape/thickness.py +++ b/CorpusCallosum/shape/thickness.py @@ -53,6 +53,7 @@ def convert_to_ras(contour: np.ndarray, vox2ras_matrix: np.ndarray, get_paramete def convert_to_ras(contour: np.ndarray, vox2ras_matrix: np.ndarray, get_parameters: Literal[True]) \ -> tuple[np.ndarray, bool, bool, bool]: ... + def convert_to_ras( contour: np.ndarray, vox2ras_matrix: np.ndarray, @@ -71,8 +72,8 @@ def convert_to_ras( Returns ------- - contour : p.ndarray - Transformed contour coordinates. + contour : np.ndarray + Transformed contour coordinates of shape (3, N). anterior_reversed : bool Only if return_parameters is True, whether anterior axis was reversed. superior_reversed : bool @@ -80,11 +81,11 @@ def convert_to_ras( swap_axes : bool Only if return_parameters is True, whether axes were swapped. """ - # converting to AS (no left-right dimension), out of plane movement is ignores, + # converting to AS (no left-right dimension), out of plane movement is ignored, # so we only do scaling, axes swapping and flipping - no rotation # translation is ignored if contour.shape[0] == 2: - # get only axis swaps + # get only axis swaps from the rotation part of the vox2ras matrix axis_swaps = np.round(vox2ras_matrix[:3, :3], 0) permutation = np.argwhere(axis_swaps != 0)[:, 1] assert len(permutation) == 3 @@ -92,8 +93,8 @@ def convert_to_ras( idx_superior = np.argwhere(permutation == 2) idx_anterior = np.argwhere(permutation == 1) - swap_axes = idx_anterior > idx_superior - if swap_axes: + # swap axes if indicated from vox2ras + if swap_axes := idx_anterior > idx_superior: # swap anterior and superior contour = contour[[1, 0]] @@ -113,6 +114,9 @@ def convert_to_ras( # voxel * vox_size = mm contour = (contour.T * scaling[1:]).T + # append a 0-R coordinate + contour = np.concatenate([np.zeros((1, contour.shape[1])), contour], axis=0) + if return_parameters: return contour, anterior_reversed, superior_reversed, swap_axes else: @@ -199,49 +203,69 @@ def find_closest_edge(point, contour): return np.argmin(distances) +@overload +def insert_point_with_thickness( + contour_in_as_space: np.ndarray, + contour_thickness: np.ndarray, + point: np.ndarray, + thickness_value: float, + return_index: Literal[False] = False, +) -> tuple[np.ndarray, np.ndarray]: ... + + +@overload +def insert_point_with_thickness( + contour_in_as_space: np.ndarray, + contour_thickness: np.ndarray, + point: np.ndarray, + thickness_value: float, + return_index: Literal[True], +) -> tuple[np.ndarray, np.ndarray, int] | list[np.ndarray, np.ndarray]: + ... + + def insert_point_with_thickness( - contour_with_thickness: list[np.ndarray], + contour_in_as_space: np.ndarray, + contour_thickness: np.ndarray, point: np.ndarray, thickness_value: float, - get_index: bool = False -) -> tuple[list[np.ndarray], int] | list[np.ndarray]: - """Insert a point and its thickness value into the contour. 
+ return_index: bool = False +) -> tuple[np.ndarray, np.ndarray, int] | tuple[np.ndarray, np.ndarray]: + """Inserts a point and its thickness value into the contour. Parameters ---------- - contour_with_thickness : list[np.ndarray] - List containing [contour_points, thickness_values]. + contour_in_as_space : np.ndarray + Array of coordinates of the contour in AS space, shape (N, 2). + contour_thickness : np.ndarray + Array of thickness values of the contour, shape (N,). point : np.ndarray 2D point to insert, shape (2,). thickness_value : float Thickness value corresponding to the point. - get_index : bool, optional + return_index : bool, default=False If True, return the index where point was inserted, by default False. Returns ------- - tuple[list[np.ndarray], int] or list[np.ndarray] - If get_index is True: - - Updated contour_with_thickness. - - Index where point was inserted. - If get_index is False: - - Updated contour_with_thickness. + contour_in_as_space : np.ndarray + Updated contour of shape (N+1, 2). + contour_thickness : np.ndarray + Updated thickness values of shape (N+1,). + insertion_index : int + The index, where the point was inserted (only if return_index is True). """ # Find closest edge for the point - edge_idx = find_closest_edge(point, contour_with_thickness[0]) + edge_idx = find_closest_edge(point, contour_in_as_space) # Insert point between edge endpoints - contour_with_thickness[0] = np.insert( - contour_with_thickness[0], edge_idx + 1, point, axis=0 - ) - contour_with_thickness[1] = np.insert( - contour_with_thickness[1], edge_idx + 1, thickness_value - ) + contour_in_as_space = np.insert(contour_in_as_space, edge_idx + 1, point, axis=0) + contour_thickness = np.insert(contour_thickness, edge_idx + 1, thickness_value) - if get_index: - return contour_with_thickness, edge_idx + 1 + if return_index: + return contour_in_as_space, contour_thickness, edge_idx + 1 else: - return contour_with_thickness + return contour_in_as_space, contour_thickness def make_mesh_from_contour( @@ -294,20 +318,17 @@ def make_mesh_from_contour( def cc_thickness( contour_2d: np.ndarray, - anterior_endpoint_idx: int, - posterior_endpoint_idx: int, + endpoint_idx: tuple[int, int], n_points: int = 100 -) -> tuple[float, float, float, np.ndarray, list[np.ndarray], list[np.ndarray], int, int]: +) -> tuple[float, float, float, np.ndarray, list[np.ndarray], np.ndarray, tuple[int, int]]: """Calculate corpus callosum thickness using Laplace equation. Parameters ---------- contour_2d : np.ndarray Array of shape (N, 2) containing contour points. - anterior_endpoint_idx : int - Index of anterior endpoint in contour. - posterior_endpoint_idx : int - Index of posterior endpoint in contour. + endpoint_idx : pair of ints + Indices of anterior and posterior endpoints in contour. n_points : int, optional Number of points for thickness measurement, by default 100. @@ -320,15 +341,13 @@ def cc_thickness( curvature : float Mean absolute curvature in degrees. midline_equidistant : np.ndarray - Equidistant points along the midline. + Equidistant points along the midline in same space as contour2d. levelpaths : list[np.ndarray] - Level paths for thickness measurement. - contour_with_thickness : list[np.ndarray] - Contour coordinates with thickness information. - anterior_endpoint_idx : int - Updated index of anterior endpoint. - posterior_endpoint_idx : int - Updated index of posterior endpoint. + Level paths for thickness measurement in same space as contour2d. 
+ contour_with_thickness : np.ndarray + Contour coordinates with thickness information in same space as contour2d of shape (N+2, 3). + endpoint_indices : pair of ints + Pair of updated indices of anterior and posterior endpoint. Notes ----- @@ -338,22 +357,23 @@ def cc_thickness( 3. Solving Laplace equation to get level sets 4. Computing thickness along level sets """ + anterior_endpoint_idx, posterior_endpoint_idx = endpoint_idx - # standardize contour indices, to get consistent levelpath directions + # standardize contour indices to start at anterior_endpoint_idx, to get consistent levelpath directions contour_2d, anterior_endpoint_idx, posterior_endpoint_idx = set_contour_zero_idx( contour_2d, anterior_endpoint_idx, anterior_endpoint_idx, posterior_endpoint_idx, ) - mesh_points, mesh_trias = make_mesh_from_contour(contour_2d) + mesh_points_contour_space, mesh_trias = make_mesh_from_contour(contour_2d) - # make points 3D by appending z=0 - mesh_points3d = np.append(mesh_points, np.zeros((mesh_points.shape[0], 1)), axis=1) + # make points 3D by appending z=0, asz space therefore is the contour space (usually AS space) with a zero z-dim + mesh_points_asz = np.append(mesh_points_contour_space, np.zeros((mesh_points_contour_space.shape[0], 1)), axis=1) # compute poisson with suppress_stdout(): - tria = TriaMesh(mesh_points3d, mesh_trias) + tria_asz = TriaMesh(mesh_points_asz, mesh_trias) # extract boundary curve - bdr = np.array(tria.boundary_loops()[0]) + bdr = np.array(tria_asz.boundary_loops()[0]) # find index of endpoints in bdr list iidx1 = np.where(bdr == anterior_endpoint_idx)[0][0] @@ -368,67 +388,69 @@ def cc_thickness( dcond[iidx1 + 1 : iidx2] = -1 # Extract path - fem = Solver(tria) + fem = Solver(tria_asz) vfunc = fem.poisson(0, (bdr, dcond)) - midline_equidistant, midline_length = tria.level_path(vfunc, level=0., n_points=n_points + 2) - midline_equidistant = midline_equidistant[:, :2] + midline_length: float + midline_equidistant_asz, midline_length = tria_asz.level_path(vfunc, level=0., n_points=n_points + 2) + midline_equidistant_contour_space: np.ndarray = midline_equidistant_asz[:, :2] - gf = compute_rotated_f(tria, vfunc) + gf = compute_rotated_f(tria_asz, vfunc) # interpolate midline to get levels to evaluate - level_of_rotated_laplace = scipy.interpolate.griddata( - tria.v[:, 0:2], gf, midline_equidistant[:, 0:2], method="cubic", + level_of_rotated_laplace_contour_space = scipy.interpolate.griddata( + tria_asz.v[:, 0:2], gf, midline_equidistant_asz[:, 0:2], method="cubic", ) # get levels to evaluate - levelpaths = [] + levelpaths_contour_space: list[np.ndarray] = [] levelpath_lengths = [] levelpath_tria_idx = [] # now, on the rotated laplace function, sample equally spaced (on midline: level_of_rotated_laplace) levelpaths - contour_with_thickness = [contour_2d.copy(), np.full(contour_2d.shape[0], np.nan)] - for current_level in level_of_rotated_laplace[1:-1]: + contour_thickness = np.full(contour_2d.shape[0], np.nan) + for current_level in level_of_rotated_laplace_contour_space[1:-1]: # levelpath starts at index zero - lvlpath, lvlpath_length, tria_idx = tria.level_path( - gf, current_level, get_tria_idx=True, - ) + levelpath_asz, lvlpath_length, tria_idx = tria_asz.level_path(gf, current_level, get_tria_idx=True) - levelpaths.append(lvlpath) + levelpaths_contour_space.append(levelpath_asz[:, :2]) levelpath_lengths.append(lvlpath_length) levelpath_tria_idx.append(tria_idx) - levelpath_start = lvlpath[0, :2] - levelpath_end = lvlpath[-1, :2] + levelpath_start = 
levelpath_asz[0, :2] + levelpath_end = levelpath_asz[-1, :2] - contour_with_thickness, inserted_idx_start = insert_point_with_thickness( - contour_with_thickness, levelpath_start, lvlpath_length, get_index=True, + contour_2d, contour_thickness, inserted_idx_start = insert_point_with_thickness( + contour_2d, contour_thickness, levelpath_start, lvlpath_length, return_index=True, ) - contour_with_thickness, inserted_idx_end = insert_point_with_thickness( - contour_with_thickness, levelpath_end, lvlpath_length, get_index=True, - ) - - # keep track of start and end indices + # keep track of start index if inserted_idx_start <= anterior_endpoint_idx: anterior_endpoint_idx += 1 - if inserted_idx_end <= anterior_endpoint_idx: - anterior_endpoint_idx += 1 - if inserted_idx_start >= posterior_endpoint_idx: posterior_endpoint_idx += 1 + + contour_2d, contour_thickness, inserted_idx_end = insert_point_with_thickness( + contour_2d, contour_thickness, levelpath_end, lvlpath_length, return_index=True, + ) + # keep track of end index + if inserted_idx_end <= anterior_endpoint_idx: + anterior_endpoint_idx += 1 if inserted_idx_end >= posterior_endpoint_idx: posterior_endpoint_idx += 1 + contour_2d_with_thickness = np.concatenate([contour_2d, contour_thickness[:, None]], axis=1) + # get curvature of path3d_resampled - curvature = compute_curvature(midline_equidistant) - out_curvature = np.abs(np.degrees(np.mean(curvature))) / len(curvature) + curvature = compute_curvature(midline_equidistant_contour_space) + mean_curvature: float = np.abs(np.degrees(np.mean(curvature))).item() / len(curvature) + mean_thickness: float = np.mean(levelpath_lengths).item() + endpoints: tuple[int, int] = (anterior_endpoint_idx, posterior_endpoint_idx) return ( midline_length, - np.mean(levelpath_lengths), - out_curvature, - midline_equidistant, - levelpaths, - contour_with_thickness, - anterior_endpoint_idx, - posterior_endpoint_idx, + mean_thickness, + mean_curvature, + midline_equidistant_contour_space, + levelpaths_contour_space, + contour_2d_with_thickness, + endpoints, ) diff --git a/CorpusCallosum/transforms/segmentation.py b/CorpusCallosum/transforms/segmentation.py index 2fcf41da..9d2c7268 100644 --- a/CorpusCallosum/transforms/segmentation.py +++ b/CorpusCallosum/transforms/segmentation.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import logging
+from typing import Literal

 import numpy as np
 from monai.transforms import MapTransform, RandomizableTransform
@@ -68,36 +70,45 @@ def __call__(self, data: dict) -> dict:
        """
        d = dict(data)

-        if 'AC_center_original' not in d:
-            d['AC_center_original'] = d['AC_center'].copy()
-        if 'PC_center_original' not in d:
-            d['PC_center_original'] = d['PC_center'].copy()
+        if "AC_center_original" not in d:
+            d["AC_center_original"] = d["AC_center"].copy()
+        if "PC_center_original" not in d:
+            d["PC_center_original"] = d["PC_center"].copy()

        if self.random_translate > 0:
            random_translate = np.random.randint(-self.random_translate, self.random_translate, size=2)
        else:
-            random_translate = (0,0,0)
+            random_translate = (0, 0)
-
-        for key in self.keys:
-            if key not in d.keys() and self.allow_missing_keys:
-                continue
-            pc_center = d['PC_center']
-            ac_center = d['AC_center']
-
-            ac_pc_bottomleft = (np.min([ac_center[1], pc_center[1]]).astype(int),
-                                np.min([ac_center[2], pc_center[2]]).astype(int))
-            ac_pc_topright = (np.max([ac_center[1], pc_center[1]]).astype(int),
-                              np.max([ac_center[2], pc_center[2]]).astype(int))
+        pc_center = d["PC_center"]
+        ac_center = d["AC_center"]
+
+        ac_pc = np.stack([ac_center, pc_center], axis=0)
+
+        ac_pc_bottomleft = np.min(ac_pc, axis=0).astype(int)
+        ac_pc_topright = np.max(ac_pc, axis=0).astype(int)
+
+        VoxPadType = np.ndarray[tuple[Literal[2]], np.dtype[int]]
+        voxel_padding: VoxPadType = np.round(self.padding_mm / d["res"]).astype(int)
+
+        crop_left = ac_pc_bottomleft[1] - int(voxel_padding[0] * 1.5) + random_translate[0]
+        crop_right = ac_pc_topright[1] + voxel_padding[0] // 2 + random_translate[0]
+        crop_top = ac_pc_bottomleft[2] - voxel_padding[1] + random_translate[1]
+        crop_bottom = ac_pc_topright[2] + voxel_padding[1] + random_translate[1]
+
+        keys_to_process = [key for key in self.keys if key in d.keys()]
+
+        if not self.allow_missing_keys and set(keys_to_process) != set(self.keys):
+            raise ValueError(f"Keys {set(self.keys) - set(keys_to_process)} are missing in the data dictionary.")

-            voxel_padding = round(self.padding_mm / d['res'])
+        if len(keys_to_process) == 0:
+            logging.getLogger(__name__).warning("No keys to process.")
+            return d

-            crop_left = ac_pc_bottomleft[0]-int(voxel_padding*1.5)+random_translate[0]
-            crop_right = ac_pc_topright[0]+voxel_padding//2+random_translate[0]
-            crop_top = ac_pc_bottomleft[1]-voxel_padding+random_translate[1]
-            crop_bottom = ac_pc_topright[1]+voxel_padding+random_translate[1]
+        first_key = keys_to_process[0]
+        d["to_pad"] = crop_left, d[first_key].shape[2] - crop_right, crop_top, d[first_key].shape[3] - crop_bottom

-            d['to_pad'] = crop_left, d[key].shape[2]-crop_right, crop_top, d[key].shape[3]-crop_bottom
+        for key in keys_to_process:
            d[key] = d[key][:, :, crop_left:crop_right, crop_top:crop_bottom]

        return d
@@ -154,16 +165,16 @@ def __call__(self, data: dict) -> dict:
        d = super().__call__(data)

        # Get the crop coordinates that were used
-        pad_left, pad_right, pad_top, pad_bottom = d['to_pad']
+        pad_left, pad_right, pad_top, pad_bottom = d["to_pad"]

        # Adjust AC and PC center coordinates based on cropping
-        if 'AC_center' in d:
-            d['AC_center'][1] = d['AC_center_original'][1] - pad_left.item()
-            d['AC_center'][2] = d['AC_center_original'][2] - pad_top.item()
+        if "AC_center" in d:
+            d["AC_center"][1] = d["AC_center_original"][1] - pad_left.item()
+            d["AC_center"][2] = d["AC_center_original"][2] - pad_top.item()

-        if 'PC_center' in d:
-            d['PC_center'][1] = d['PC_center_original'][1] - pad_left.item()
-            d['PC_center'][2] = d['PC_center_original'][2] - pad_top.item()
+        if "PC_center" in d:
+            d["PC_center"][1] = d["PC_center_original"][1] - pad_left.item()
+            d["PC_center"][2] = d["PC_center_original"][2] - pad_top.item()

        return d
diff --git a/CorpusCallosum/utils/mapping_helpers.py b/CorpusCallosum/utils/mapping_helpers.py
index 821d4432..37483bdf 100644
--- a/CorpusCallosum/utils/mapping_helpers.py
+++ b/CorpusCallosum/utils/mapping_helpers.py
@@ -7,7 +7,7 @@
 from scipy.ndimage import affine_transform

 from CorpusCallosum.data.constants import CC_LABEL, FORNIX_LABEL
-from FastSurferCNN.utils import logging
+from FastSurferCNN.utils import logging, AffineMatrix4x4
 from FastSurferCNN.utils.parallel import thread_executor

 logger = logging.get_logger(__name__)
@@ -201,8 +201,8 @@ def calc_mapping_to_standard_space(

 def apply_transform_to_volume(
     orig_image: nib.analyze.SpatialImage,
-    transform: npt.NDArray[float],
-    affine: npt.NDArray[float],
+    vox2vox: AffineMatrix4x4,
+    affine: AffineMatrix4x4,
     header: nib.freesurfer.mghformat.MGHHeader | None = None,
     output_path: str | Path | None = None,
     output_size: np.ndarray | None = None,
@@ -214,14 +214,14 @@
    ----------
    orig_image : nibabel.analyze.SpatialImage
        Input volume.
-    transform : np.ndarray
-        Transformation matrix to apply.
-    affine : np.ndarray
-        Affine matrix for the output image.
+    vox2vox : AffineMatrix4x4
+        Transformation matrix to apply to the data; this maps from input to output space.
+    affine : AffineMatrix4x4
+        vox2ras matrix of the output image, only relevant if output_path is given.
    header : nib.freesurfer.mghformat.MGHHeader, optional
-        Header for the output image, if None will default to orig_image header.
+        Header for the output image, only relevant if output_path is given; defaults to the orig_image header if None.
    output_path : str or Path, optional
-        Path to save transformed volume.
+        If output_path is provided, saves the result under this path.
    output_size : np.ndarray, optional
        Size of output volume, uses input size by default (`None`).
    order : int, default=1
@@ -234,23 +234,21 @@

    Notes
    -----
-    Uses scipy.ndimage.affine_transform for the transformation.
-    If output_path is provided, saves the result as a MGH file.
+    Uses `scipy.ndimage.affine_transform` for the transformation, and inverts vox2vox internally as required by
+    `affine_transform`.
    """
    if output_size is None:
        output_size = np.array(orig_image.shape)
    if header is None:
        header = orig_image.header
-    transformed = affine_transform(
-        orig_image.get_data(),
-        np.linalg.inv(transform),
-        output_shape=output_size,
-        order=order,
-    )
+    # transform / resample the volume with vox2vox, note this needs to be the inverse of input2output vox2vox!
+    # affine_transform definition is: input_coord = matrix @ output_coord + offset ( == MATRIX_HOM @ output_coord_hom)
+    # --> output_coord = inv(matrix) @ (input_coord - offset) ( == inv(MATRIX_HOM) @ input_coord_hom)
+    resampled = affine_transform(orig_image.get_fdata(), np.linalg.inv(vox2vox), output_shape=output_size, order=order)
    if output_path is not None:
        logger.info(f"Saving transformed volume to {output_path}")
-        nib.save(nib.MGHImage(transformed.astype(orig_image.get_data_dtype()), affine, header), output_path)
-    return transformed
+        nib.save(nib.MGHImage(resampled.astype(orig_image.get_data_dtype()), affine, header), output_path)
+    return resampled


 def make_affine(simpleITKImage: sitk.Image) -> npt.NDArray[float]:
@@ -387,7 +385,7 @@ def interpolate_midplane(
    """
    # slice_thickness = 9+slices_to_analyze-1

-    # make grid of 9 slices in the fsaverage middle
+    # make grid of 9 slices in the fsaverage middle
+    # (cube from 123.5,0.5,0.5 to 132.5,255.5,255.5 (including end points, 1mm spacing))
    x_coords = np.linspace(
        124 - slices_to_analyze // 2,
diff --git a/CorpusCallosum/utils/visualization.py b/CorpusCallosum/utils/visualization.py
index 285e0d84..3b98ab29 100644
--- a/CorpusCallosum/utils/visualization.py
+++ b/CorpusCallosum/utils/visualization.py
@@ -133,7 +133,7 @@ def plot_contours(
     output_path: str | Path | None = None,
     ac_coords: np.ndarray | None = None,
     pc_coords: np.ndarray | None = None,
-    vox_size: float | None = None,
+    vox_size: tuple[float, float, float] | None = None,
     title: str = "",
 ) -> None:
    """Creates a figure of the contours (shape) and the subdivisions of the corpus callosum.
@@ -154,8 +154,9 @@ def plot_contours(
        AC coordinates for visualization (ignore AC on None).
    pc_coords : np.ndarray, optional
        PC coordinates for visualization (ignore PC on None).
-    vox_size : float, optional
-        Voxel size for scaling.
+    vox_size : triplet of floats, optional
+        LIA-oriented voxel size for scaling, optional if none of split_contours, midline_equidistant, or levelpaths are
+        provided.
    title : str, default=""
        Title for the plot.

@@ -165,15 +166,18 @@ def plot_contours(
    Returns
    -------
    None
        If output_path is provided, saves the plot to that location.
    """
-    # scale contour data by vox_size
-    if split_contours:
-        split_contours = np.stack(split_contours, axis=0) / vox_size
-    if midline_equidistant:
-        midline_equidistant = midline_equidistant / vox_size
-    if levelpaths:
-        levelpaths = np.stack(levelpaths, axis=0) / vox_size
+    if vox_size is None and any(x is not None for x in (split_contours, midline_equidistant, levelpaths)):
+        raise ValueError("vox_size must be provided if split_contours, midline_equidistant, or levelpaths are given.")

-    has_first_plot = bool(split_contours) or bool(ac_coords) or bool(pc_coords)
+    # convert vox_size from LIA to AS
+    vox_size_ras = np.asarray([vox_size[0], vox_size[2], vox_size[1]]) if vox_size is not None else None
+
+    # scale contour data by vox_size to convert from AS to AS-aligned voxel space
+    _split_contours = [] if split_contours is None else [sp / vox_size_ras[1:, None] for sp in split_contours]
+    _midline_equi = np.zeros((0, 2)) if midline_equidistant is None else midline_equidistant / vox_size_ras[None, 1:]
+    _levelpaths = [] if levelpaths is None else [lp / vox_size_ras[None, 1:] for lp in levelpaths]
+
+    has_first_plot = not (len(_split_contours) == 0 and ac_coords is None and pc_coords is None)

    num_plots = 1 + int(has_first_plot)
    _, ax = plt.subplots(1, num_plots, sharex=True, sharey=True, figsize=(15, 10))
@@ -184,32 +188,36 @@ def plot_contours(
    if has_first_plot:
        ax[current_plot].imshow(transformed[transformed.shape[0] // 2], cmap="gray")
        ax[current_plot].set_title(title)
-        if split_contours:
-            for i, this_contour in enumerate(split_contours):
+        if _split_contours:
+            for i, this_contour in enumerate(_split_contours):
                ax[current_plot].fill(this_contour[0, :], -this_contour[1, :], color="steelblue", alpha=0.25)
                kwargs = {"color": "mediumblue", "linewidth": 0.7, "linestyle": "solid" if i != 0 else "dotted"}
                ax[current_plot].plot(this_contour[0, :], -this_contour[1, :], **kwargs)
-        if ac_coords:
+        if ac_coords is not None:
            ax[current_plot].scatter(ac_coords[1], ac_coords[0], color="red", marker="x")
-        if pc_coords:
+        if pc_coords is not None:
            ax[current_plot].scatter(pc_coords[1], pc_coords[0], color="blue", marker="x")
        current_plot += int(has_first_plot)

-    reference_contour = split_contours[0]
    ax[current_plot].imshow(transformed[transformed.shape[0] // 2], cmap="gray")
-    for this_path in levelpaths:
+    for this_path in _levelpaths:
        ax[current_plot].plot(this_path[:, 0], -this_path[:, 1], color="brown", linewidth=0.8)
    ax[current_plot].set_title("Midline & Levelpaths")
-    ax[current_plot].plot(midline_equidistant[:, 0], -midline_equidistant[:, 1], color="red")
-    ax[current_plot].plot(reference_contour[0, :], -reference_contour[1, :], color="red", linewidth=0.5)
+    if _midline_equi.shape[0] > 0:
+        ax[current_plot].plot(_midline_equi[:, 0], -_midline_equi[:, 1], color="red")
+    if _split_contours:
+        reference_contour = _split_contours[0]
+        ax[current_plot].plot(reference_contour[0, :], -reference_contour[1, :], color="red", linewidth=0.5)

    padding = 30
    for a in ax.flatten():
        a.set_aspect("equal", adjustable="box")
        a.axis("off")
-        # get bounding box of contours
-        a.set_xlim(reference_contour[0, :].min() - padding, reference_contour[0, :].max() + padding)
-        a.set_ylim((-reference_contour[1, :]).max() + padding, (-reference_contour[1, :]).min() - padding)
+        if _split_contours:
+            reference_contour = _split_contours[0]
+            # get bounding box of contours
+            a.set_xlim(reference_contour[0, :].min() - padding, reference_contour[0, :].max() + padding)
+            a.set_ylim((-reference_contour[1, 
:]).min() - padding) Path(output_path).parent.mkdir(parents=True, exist_ok=True) plt.savefig(output_path, dpi=300, bbox_inches="tight") diff --git a/run_fastsurfer.sh b/run_fastsurfer.sh index 6727bddd..39d5e7ac 100755 --- a/run_fastsurfer.sh +++ b/run_fastsurfer.sh @@ -1117,7 +1117,7 @@ then # generate file names of for the analysis asegdkt_withcc_segfile="$(add_file_suffix "$asegdkt_segfile" "withCC")" asegdkt_withcc_vinn_statsfile="$(add_file_suffix "$asegdkt_vinn_statsfile" "withCC")" - aseg_auto_statsfile="$(basename "$aseg_vinn_statsfile")/aseg.auto.mgz" + aseg_auto_statsfile="$(dirname "$aseg_vinn_statsfile")/aseg.auto.mgz" # note: callosum manedit currently only affects inpainting and not internal FastSurferCC processing (surfaces etc) callosum_seg_manedit="$(add_file_suffix "$callosum_seg" "manedit")" # generate callosum segmentation, mesh, shape and downstream measure files From 7137ba865769c8c0e8d9007a07cff8a04b8073d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Mon, 8 Dec 2025 15:53:11 +0100 Subject: [PATCH 42/68] Fix docstrings and formatting in mesh.py other fixes?! --- .../data/generate_fsaverage_centroids.py | 2 +- CorpusCallosum/fastsurfer_cc.py | 40 +++++++----- CorpusCallosum/localization/inference.py | 14 +++-- CorpusCallosum/shape/mesh.py | 61 +++++++++---------- CorpusCallosum/shape/postprocessing.py | 17 +++--- CorpusCallosum/utils/mapping_helpers.py | 61 +++++++++---------- CorpusCallosum/utils/visualization.py | 14 +++-- 7 files changed, 111 insertions(+), 98 deletions(-) diff --git a/CorpusCallosum/data/generate_fsaverage_centroids.py b/CorpusCallosum/data/generate_fsaverage_centroids.py index 4dd874ac..b1ef7b19 100644 --- a/CorpusCallosum/data/generate_fsaverage_centroids.py +++ b/CorpusCallosum/data/generate_fsaverage_centroids.py @@ -26,7 +26,7 @@ import nibabel as nib import numpy as np -from read_write import convert_numpy_to_json_serializable, calc_ras_centroids_from_seg +from read_write import calc_ras_centroids_from_seg, convert_numpy_to_json_serializable import FastSurferCNN.utils.logging as logging diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index 16ea8fa2..e41e3fdd 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -338,20 +338,30 @@ def options_parse() -> argparse.Namespace: args.aseg_name = args.subject_dir / DEFAULT_INPUT_PATHS["aseg_name"] else: print("WARNING: Not providing subject_dir leads to discarding of files with relative paths!") - args.subject_dir = Path("/dev/null/no-subject-dir-set") - - all_paths = ("segmentation", "segmentation_in_orig", "cc_measures", "upright_lta", "orient_volume_lta", "cc_surf", - "softlabels_cc", "softlabels_fn", "softlabels_background", "cc_mid_measures", "thickness_overlay", - "qc_image", "thickness_image", "cc_html") - - # Create parent directories for all output paths - for path_name in all_paths: - path: Path | None = getattr(args, path_name, None) - if isinstance(path, Path) and not args.subject_dir and not path.is_absolute(): - # set path to none in arguments - # FIXME: Should there be a check, if a specific "path_name" is mandatory? 
-            print(f"WARNING: Not writing {path_name}, because --sd and --sid are not specified and {path} is relative.")
-            setattr(args, path_name, None)
+        args.subject_dir = None
+        for arg, path in (("--aseg_name", args.aseg_name), ("--conformed_name", args.conf_name)):
+            if path is None or not Path(path).is_absolute():
+                parser.error(
+                    f"When not passing --sd, the arguments --aseg_name and --conformed_name must be "
+                    f"absolute paths! But the argument passed to {arg} was {path}, which is not absolute."
+                )
+
+    all_paths = ("segmentation", "segmentation_in_orig", "cc_measures", "upright_lta", "orient_volume_lta",
+                 "cc_surf", "softlabels_cc", "softlabels_fn", "softlabels_background", "cc_mid_measures",
+                 "thickness_overlay", "qc_image", "thickness_image", "cc_html")
+
+    warnings_paths = []
+    # Create parent directories for all output paths
+    for path_name in all_paths:
+        path: Path | None = getattr(args, path_name, None)
+        if isinstance(path, Path) and not args.subject_dir and not path.is_absolute():
+            # set path to none in arguments
+            warnings_paths.append(path_name)
+            setattr(args, path_name, None)
+    if warnings_paths:
+        _warnings_paths = "' '".join(warnings_paths)
+        print(f"WARNING: Not writing '{_warnings_paths}', because --sd and --sid are not specified and "
+              f"their paths are relative.")

    return args
@@ -646,7 +656,7 @@ def main(
    if subdivisions is None:
        subdivisions = [1 / 6, 1 / 2, 2 / 3, 3 / 4]

-    subject_dir = Path(subject_dir)
+    subject_dir = Path("/dev/null/no-subject-dir" if subject_dir is None else subject_dir)

    logger.info("Starting corpus callosum analysis pipeline")
    logger.info(f"Input MRI: {conf_name}")
diff --git a/CorpusCallosum/localization/inference.py b/CorpusCallosum/localization/inference.py
index fd6df1b1..b9b2a7ed 100644
--- a/CorpusCallosum/localization/inference.py
+++ b/CorpusCallosum/localization/inference.py
@@ -224,13 +224,18 @@ def run_inference_on_slice(
        Detected PC voxel coordinates with shape (2,) containing its [y,x] positions.
    """
+    if num_iterations < 1:
+        raise ValueError("localization inference with less than 1 iteration is invalid!")
+
+    pc_coords, ac_coords = center_pt[None], center_pt[None]
+    crop_left, crop_top = 0, 0
    # Run inference
-    for i in range(num_iterations):
+    for _ in range(num_iterations):
        pc_coords, ac_coords, _, (crop_left, crop_top) = run_inference(model, image_slice, center_pt)
        center_pt = np.mean(np.stack([ac_coords, pc_coords], axis=0), axis=(0, 1))

    # average ac and pc coords across sagittal slices
-    pc_coords = np.mean(pc_coords, axis=0, keepdims=True)
-    ac_coords = np.mean(ac_coords, axis=0, keepdims=True)
+    _pc_coords = np.mean(pc_coords, axis=0)
+    _ac_coords = np.mean(ac_coords, axis=0)

    if debug_output is not None:
        import matplotlib.pyplot as plt
@@ -245,5 +250,4 @@ def run_inference_on_slice(
        plt.savefig(debug_output, bbox_inches='tight')
        plt.close()

-
-    return ac_coords[0], pc_coords[0]
+    return _ac_coords, _pc_coords
diff --git a/CorpusCallosum/shape/mesh.py b/CorpusCallosum/shape/mesh.py
index 4e889d59..02fadfe7 100644
--- a/CorpusCallosum/shape/mesh.py
+++ b/CorpusCallosum/shape/mesh.py
@@ -952,8 +952,8 @@ def smooth_contour(self, contour_idx: int, window_size: int = 5) -> None:
    ----------
    contour_idx : int
        Index of the contour to smooth.
-    window_size : int, optional
-        Size of the smoothing window, by default 5.
+    window_size : int, default=5
+        Size of the smoothing window.
Notes ----- @@ -969,6 +969,7 @@ def smooth_contour(self, contour_idx: int, window_size: int = 5) -> None: def plot_cc_contour_with_levelsets( self, contour_idx: int = 0, + #FIXME: levelpaths is not used levelpaths: list | None = None, title: str | None = None, save_path: str | None = None, @@ -982,28 +983,26 @@ def plot_cc_contour_with_levelsets( Parameters ---------- - contour_idx : int, optional + contour_idx : int, default=0 Index of the contour to plot, by default 0. - levelpaths : list or None, optional - List of levelset paths. If None, uses stored levelpaths, by default None. - title : str or None, optional - Title for the plot, by default None. - save_path : str or None, optional - Path to save the plot. If None, displays interactively, by default None. - colorbar : bool, optional - Whether to show the colorbar, by default True. - mode: str, optional - Mode of the plot, by default "p-value". Can be "p-value" or "icc". + levelpaths : list, optional + List of levelset paths. If None, uses stored levelpaths. + title : str, optional + Title for the plot. + save_path : str, optional + Path to save the plot. If None, displays interactively. + colorbar : bool, default=True + Whether to show the colorbar. + mode : {"p-value", "icc"}, default="p-value" + Mode of the plot. + Returns ------- matplotlib.figure.Figure The created figure object. """ - plot_values = np.array(self.thickness_values[contour_idx][~np.isnan(self.thickness_values[contour_idx])])[ - ::-1 - ] - + plot_values = np.array(self.thickness_values[contour_idx][~np.isnan(self.thickness_values[contour_idx])])[::-1] points, trias = make_mesh_from_contour(self.contours[contour_idx], max_volume=0.5, min_angle=25, verbose=False) # make points 3D by adding zero @@ -1077,7 +1076,7 @@ def plot_cc_contour_with_levelsets( # Use griddata to perform smooth interpolation - using 'linear' instead of 'cubic' # and properly formatting the input points grid_values = scipy.interpolate.griddata( - (all_level_points_x, all_level_points_y), all_level_values, (x_grid, y_grid), method="linear", fill_value=0 + (all_level_points_x, all_level_points_y), all_level_values, (x_grid, y_grid), method="linear", fill_value=0, ) # smooth the grid_values @@ -1086,16 +1085,15 @@ def plot_cc_contour_with_levelsets( # Apply the mask to only show values inside the contour masked_values = np.where(mask, grid_values, np.nan) - if mode == "p-value": # Sample colormaps colors1 = plt.cm.binary([0.4] * 128) colors2 = plt.cm.hot(np.linspace(0.8, 0.1, 128)) - - - else: + elif mode == "icc": colors1 = plt.cm.Blues(np.linspace(0, 1, 128)) colors2 = plt.cm.binary([0.4] * 128) + else: + raise ValueError(f"Invalid mode '{mode}'") # Combine the color samples colors = np.vstack((colors2, colors1)) @@ -1113,7 +1111,7 @@ def plot_cc_contour_with_levelsets( # Plot the filled contour with interpolated colors plt.imshow( masked_values, - extent=[x_min - margin, x_max + margin, y_min - margin, y_max + margin], + extent=(x_min - margin, x_max + margin, y_min - margin, y_max + margin), origin="lower", cmap=cmap, alpha=1, @@ -1125,7 +1123,7 @@ def plot_cc_contour_with_levelsets( plt.imshow( masked_values, - extent=[x_min - margin, x_max + margin, y_min - margin, y_max + margin], + extent=(x_min - margin, x_max + margin, y_min - margin, y_max + margin), origin="lower", cmap=cmap, alpha=1, @@ -1136,8 +1134,6 @@ def plot_cc_contour_with_levelsets( transform=transform, ) - - if colorbar: # Add a colorbar cbar = plt.colorbar(aspect=10) @@ -1605,13 +1601,12 @@ def to_fs_coordinates( ----- Mesh 
coordinates seem to be in ASR (Anterior-Superior-Right) orientation, with
         the coordinate system origin on *the* midslice.
-
-        The function:
-        1. convert from mesh coordinates (LSA and voxel coordinates) to fsaverage voxel coordinates (LIA, origin).
-           a. Converts coordinates from ASR to LSA orientation.
-           b. Converts to voxel coordinates using voxel size.
-           c. Centers LR coordinates and flips SI coordinates.
-        2. Applies vox2ras_tkr transformation to get final coordinates.
+        The function performs the following:
+        1. Convert from mesh coordinates (LSA and voxel coordinates) to fsaverage voxel coordinates (LIA, origin).
+           a. Convert coordinates from ASR to LSA orientation.
+           b. Convert to voxel coordinates using voxel size.
+           c. Center LR coordinates and flip SI coordinates.
+        2. Apply vox2ras_tkr transformation to get final coordinates.
         """
 
         # to voxel coordinates
diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py
index 3dcd5454..42b20c4d 100644
--- a/CorpusCallosum/shape/postprocessing.py
+++ b/CorpusCallosum/shape/postprocessing.py
@@ -222,17 +222,21 @@ def recon_cc_surf_measures_multi(
             logger.info(f"Calculating CC measurements for slice {slice_idx+1}{progress}")
         cc_measures, contour_in_as_space_and_thickness, endpoint_idxs = _results
         contour_in_as_space, thickness_values = np.split(contour_in_as_space_and_thickness, (2,), axis=1)
-        cc_mesh.add_contour(start_slice-slice_idx, contour_in_as_space, thickness_values[:, 0], start_end_idx=endpoint_idxs)
+        cc_mesh.add_contour(start_slice - slice_idx, contour_in_as_space, thickness_values[:, 0], endpoint_idxs)
         if cc_measures is None:
             # this should not happen, but just in case
             logger.warning(f"Slice index {slice_idx+1}{progress} returned result `None`")
         slice_cc_measures.append(cc_measures)
 
-        if logger.getEffectiveLevel() <= logging.INFO and subject_dir.has_attribute("cc_qc_image"):
+        if subject_dir.has_attribute("cc_qc_image"):
             qc_img = subject_dir.filename_by_attribute("cc_qc_image")
             if logger.getEffectiveLevel() <= logging.DEBUG:
-                qc_img = (qc_img.parent / f"{qc_img.stem}_slice_{slice_idx}{qc_img.suffix}").with_suffix(".png")
+                qc_slice_img = (qc_img.parent / f"{qc_img.stem}_slice_{slice_idx}{qc_img.suffix}").with_suffix(".png")
+                if slice_idx == num_slices // 2:
+                    qc_img = qc_img, qc_slice_img
+                else:
+                    qc_img = qc_slice_img
 
             if logger.getEffectiveLevel() <= logging.DEBUG or slice_idx == num_slices // 2:
                 logger.info(f"Saving segmentation qc image to {qc_img}")
@@ -250,7 +254,7 @@
                         ac_coords=ac_coords,
                         pc_coords=pc_coords,
                         vox_size=vox_size,
-                        title=f"CC Subsegmentation by {subdivision_method} (Slice {slice_idx})",
+                        title=f"CC Subsegmentation by {subdivision_method} (Slice {slice_idx + 1})",
                     )
                 )
 
@@ -545,7 +549,7 @@ def get_unique_contour_points(split_contours: list[tuple[np.ndarray, np.ndarray]
 def make_subdivision_mask(
     slice_shape: tuple[int, int],
     split_contours: list[tuple[np.ndarray, np.ndarray]],
-    vox_size: tuple[float, float, float]
+    vox_size: tuple[float, float, float],
 ) -> np.ndarray:
     """Create a mask for subdividing the corpus callosum based on split contours.
 
@@ -557,7 +561,7 @@
         List of contours defining the subdivisions.
         Each contour is a tuple of x and y coordinates.
     vox_size : triplet of floats
-
+        The voxel sizes of the image grid in LIA orientation.
 
     Returns
     -------
@@ -574,7 +578,6 @@
     4. For each subdivision line:
        - Tests which points lie to the right of the line.
       - Updates labels for those points.
- """ # unique contour points are the points where sub-division lines were inserted diff --git a/CorpusCallosum/utils/mapping_helpers.py b/CorpusCallosum/utils/mapping_helpers.py index 37483bdf..c36ccd26 100644 --- a/CorpusCallosum/utils/mapping_helpers.py +++ b/CorpusCallosum/utils/mapping_helpers.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Literal import nibabel as nib import numpy as np @@ -7,9 +8,11 @@ from scipy.ndimage import affine_transform from CorpusCallosum.data.constants import CC_LABEL, FORNIX_LABEL -from FastSurferCNN.utils import logging, AffineMatrix4x4 +from FastSurferCNN.utils import AffineMatrix4x4, logging from FastSurferCNN.utils.parallel import thread_executor +Vector3D = np.ndarray[tuple[Literal[3]], np.dtype[float]] + logger = logging.get_logger(__name__) @@ -117,10 +120,10 @@ def apply_transform_to_pt(pts: npt.NDArray[float], T: npt.NDArray[float], inv: b def calc_mapping_to_standard_space( orig: "nib.Nifti1Image", - ac_coords_3d: npt.NDArray[float], - pc_coords_3d: npt.NDArray[float], - orig_fsaverage_vox2vox: npt.NDArray[float], -) -> tuple[npt.NDArray[float], npt.NDArray[float], npt.NDArray[float], npt.NDArray[float], npt.NDArray[float]]: + ac_coords_3d: Vector3D, + pc_coords_3d: Vector3D, + orig_fsaverage_vox2vox: AffineMatrix4x4, +) -> tuple[AffineMatrix4x4, Vector3D, Vector3D, Vector3D, Vector3D]: """Get transformations to map image to standard space. Parameters @@ -131,7 +134,7 @@ def calc_mapping_to_standard_space( AC coordinates in 3D space. pc_coords_3d : np.ndarray PC coordinates in 3D space. - orig_fsaverage_vox2vox : np.ndarray + orig_fsaverage_vox2vox : AffineMatrix4x4 Transformation matrix from original to fsaverage space. Returns @@ -153,50 +156,44 @@ def calc_mapping_to_standard_space( nod_correct_2d = correct_nodding(ac_coords_3d[1:3], pc_coords_3d[1:3]) # convert 2D nodding correction to 3D transformation matrix - nod_correct_3d = np.eye(4) + nod_correct_3d: AffineMatrix4x4 = np.eye(4, dtype=float) nod_correct_3d[1:3, 1:3] = nod_correct_2d[:2, :2] # Copy rotation part to y,z axes # Copy translation part to y,z axes (usually no translation) nod_correct_3d[1:3, 3] = nod_correct_2d[:2, 2] - ac_coords_after_nodding = apply_transform_to_pt( - ac_coords_3d, nod_correct_3d, inv=False + ac_coords_after_nodding: Vector3D = apply_transform_to_pt( + ac_coords_3d, nod_correct_3d, inv=False, ) - pc_coords_after_nodding = apply_transform_to_pt( - pc_coords_3d, nod_correct_3d, inv=False + pc_coords_after_nodding: Vector3D = apply_transform_to_pt( + pc_coords_3d, nod_correct_3d, inv=False, ) - ac_to_center_translation = np.eye(4) + ac_to_center_translation: AffineMatrix4x4 = np.eye(4, dtype=float) ac_to_center_translation[:3, 3] = image_center - ac_coords_after_nodding # correct nodding - ac_coords_standardized = apply_transform_to_pt( - ac_coords_after_nodding, ac_to_center_translation, inv=False + ac_coords_standardized: Vector3D = apply_transform_to_pt( + ac_coords_after_nodding, ac_to_center_translation, inv=False, ) - pc_coords_standardized = apply_transform_to_pt( - pc_coords_after_nodding, ac_to_center_translation, inv=False + pc_coords_standardized: Vector3D = apply_transform_to_pt( + pc_coords_after_nodding, ac_to_center_translation, inv=False, ) - standardized_to_orig_vox2vox = ( + standardized_to_orig_vox2vox: AffineMatrix4x4 = ( np.linalg.inv(orig_fsaverage_vox2vox) @ np.linalg.inv(nod_correct_3d) @ np.linalg.inv(ac_to_center_translation) ) # calculate ac & pc in space of mri input image - ac_coords_orig = apply_transform_to_pt( - 
ac_coords_standardized, standardized_to_orig_vox2vox, inv=False
+    ac_coords_orig: Vector3D = apply_transform_to_pt(
+        ac_coords_standardized, standardized_to_orig_vox2vox, inv=False,
     )
-    pc_coords_orig = apply_transform_to_pt(
-        pc_coords_standardized, standardized_to_orig_vox2vox, inv=False
-    )
-
-    return (
-        standardized_to_orig_vox2vox,
-        ac_coords_standardized,
-        pc_coords_standardized,
-        ac_coords_orig,
-        pc_coords_orig,
+    pc_coords_orig: Vector3D = apply_transform_to_pt(
+        pc_coords_standardized, standardized_to_orig_vox2vox, inv=False,
     )
+    # FIXME: incorrect docstring
+    return standardized_to_orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig
 
 
 def apply_transform_to_volume(
@@ -217,13 +214,13 @@
     vox2vox : np.ndarray
         Transformation matrix to apply to the data, this is from input-to-output space.
     affine : AffineMatrix4x4, optional
-        vox2ras matrix of the output image, only relevant if output_path is given.
-    header : nib.freesurfer.mghformat.MGHHeader, optional
+        The vox2ras matrix of the output image, only relevant if output_path is given.
+    header : nibabel.freesurfer.mghformat.MGHHeader, optional
         Header for the output image, only relevant if output_path is given, if None will default to orig_image header.
     output_path : str or Path, optional
         If output_path is provided, saves the result under this path.
     output_size : np.ndarray, optional
-        Size of output volume, uses input size by default (`None`).
+        Size of the output volume; defaults to the input size if `None`.
     order : int, default=1
         Order of interpolation.
diff --git a/CorpusCallosum/utils/visualization.py b/CorpusCallosum/utils/visualization.py
index 3b98ab29..fb3742d2 100644
--- a/CorpusCallosum/utils/visualization.py
+++ b/CorpusCallosum/utils/visualization.py
@@ -130,7 +130,7 @@ def plot_contours(
     split_contours: list[np.ndarray] | None = None,
     midline_equidistant: np.ndarray | None = None,
     levelpaths: list[np.ndarray] | None = None,
-    output_path: str | Path | None = None,
+    output_path: str | Path | list[Path] | None = None,
     ac_coords: np.ndarray | None = None,
     pc_coords: np.ndarray | None = None,
     vox_size: tuple[float, float, float] | None = None,
@@ -148,7 +148,7 @@
         Midline points at equidistant spacing (ignore midline on None).
     levelpaths : list[np.ndarray], optional
         List of level paths for visualization (ignore level paths on None).
-    output_path : str or Path, optional
+    output_path : str or Path or list of Paths, optional
         Path to save the plot (do not save on None).
     ac_coords : np.ndarray, optional
         AC coordinates for visualization (ignore AC on None).
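The widened `output_path` type above lets `plot_contours` write the same figure to several destinations. As a minimal standalone sketch of this save-or-show pattern (the helper name `save_or_show` is hypothetical and not part of this patch):

```python
from pathlib import Path

import matplotlib.pyplot as plt


def save_or_show(fig: plt.Figure, output_path: str | Path | list[Path] | None = None) -> None:
    """Save `fig` to one or several paths, or display it interactively if None."""
    if output_path is None:
        plt.show()
        return
    # Normalize a single path into a one-element list, then save to every target.
    paths = output_path if isinstance(output_path, (list, tuple)) else [output_path]
    for path in map(Path, paths):
        path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(path, dpi=300, bbox_inches="tight")
```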
@@ -180,7 +180,7 @@ def plot_contours( has_first_plot = not (len(_split_contours) == 0 and ac_coords is None and pc_coords is None) num_plots = 1 + int(has_first_plot) - _, ax = plt.subplots(1, num_plots, sharex=True, sharey=True, figsize=(15, 10)) + fig, ax = plt.subplots(1, num_plots, sharex=True, sharey=True, figsize=(15, 10)) # NOTE: For all plots imshow shows y inverted current_plot = 0 @@ -219,5 +219,9 @@ def plot_contours( a.set_xlim(reference_contour[0, :].min() - padding, reference_contour[0, :].max() + padding) a.set_ylim((-reference_contour[1, :]).max() + padding, (-reference_contour[1, :]).min() - padding) - Path(output_path).parent.mkdir(parents=True, exist_ok=True) - plt.savefig(output_path, dpi=300, bbox_inches="tight") + if output_path is None: + return plt.show() + for _output_path in (output_path if isinstance(output_path, (list, tuple)) else [output_path]): + Path(_output_path).parent.mkdir(parents=True, exist_ok=True) + fig.savefig(_output_path, dpi=300, bbox_inches="tight") + return None From 3780abe242ed6a0d64268c46e458f974901b1741 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Mon, 8 Dec 2025 16:04:53 +0100 Subject: [PATCH 43/68] Fix ruff errors --- CorpusCallosum/segmentation/inference.py | 2 +- CorpusCallosum/shape/postprocessing.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CorpusCallosum/segmentation/inference.py b/CorpusCallosum/segmentation/inference.py index a1a6cc5b..d3feaa46 100644 --- a/CorpusCallosum/segmentation/inference.py +++ b/CorpusCallosum/segmentation/inference.py @@ -103,7 +103,7 @@ def run_inference( pc_center : np.ndarray Posterior commissure coordinates. voxel_size : a pair of floats - Voxel size fo inferior/superior and anterior/posterior direction in mm. + Voxel size of inferior/superior and anterior/posterior direction in mm. device : torch.device or None, optional Device to run inference on, by default None. If None, uses the device of the model. diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index 42b20c4d..cd167d96 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -361,7 +361,7 @@ def recon_cc_surf_measure( Dictionary containing measurements if successful. contour_with_thickness : np.ndarray Contour points with thickness information. - endpoint_indices : paor of ints + endpoint_indices : pair of ints Indices of the anterior and posterior endpoints on the contour. 
Raises From 5024d0d3d0ce99ae41e3625584f8a4b51814f1ae Mon Sep 17 00:00:00 2001 From: ClePol Date: Tue, 9 Dec 2025 11:50:18 +0100 Subject: [PATCH 44/68] updated helptext --- CorpusCallosum/fastsurfer_cc.py | 55 +++++++++++++++++---------------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index e41e3fdd..8cfb2832 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -183,112 +183,113 @@ def _slice_selection(a: str) -> SliceSelection: ) add_arguments(advanced, ["threads"]) advanced.add_argument( - "--upright_volume", + "--segmentation", "--seg", type=path_or_none, - help="Path for upright volume output.", - default=None, + help="Output path for corpus callosum and fornix segmentation output.", + default=Path(DEFAULT_OUTPUT_PATHS["segmentation"]), ) advanced.add_argument( - "--segmentation", "--seg", + "--segmentation_in_orig", type=path_or_none, - help="Path for corpus callosum and fornix segmentation 3D image.", - default=Path(DEFAULT_OUTPUT_PATHS["segmentation"]), + help="Output path for corpus callosum and fornix segmentation output in the input MRI space.", + default=DEFAULT_OUTPUT_PATHS["segmentation_in_orig"], ) advanced.add_argument( "--cc_measures", type=path_or_none, - help="Path for surface-based corpus callosum measures describing shape and volume for each image slice.", + help="Output path for surface-based corpus callosum measures describing shape and volume for each image slice.", default=Path(DEFAULT_OUTPUT_PATHS["cc_measures"]), ) advanced.add_argument( "--cc_mid_measures", type=path_or_none, - help="Path for surface-based corpus callosum measures of the midslice describing CC shape and volume.", + help="Output path for surface-based corpus callosum measures of the midslice describing CC shape and volume.", default=DEFAULT_OUTPUT_PATHS["cc_markers"], ) advanced.add_argument( "--upright_lta", type=path_or_none, - help="Path for upright LTA transform. This makes sure the midplane is at 128 in LR direction, but no nodding " - "correction is applied.", + help="Output path for upright LTA transform. This makes sure the midplane is at 128 in LR direction, " + "but no nodding correction is applied.", default=DEFAULT_OUTPUT_PATHS["upright_lta"], ) advanced.add_argument( - "--orient_volume_lta", + "--upright_volume", type=path_or_none, - help="Path for orientation volume LTA transform. This makes sure the midplane is at 128 in LR direction, and " - "the anterior and posterior commisures are on the coordinate line, standardizing the head orientation.", - default=DEFAULT_OUTPUT_PATHS["orient_volume_lta"], + help="Output path for upright volume (input image with cc_up.lta applied).", + default=None, ) advanced.add_argument( - "--segmentation_in_orig", + "--orient_volume_lta", type=path_or_none, - help="Path for corpus callosum and fornix segmentation in the input MRI space.", - default=DEFAULT_OUTPUT_PATHS["segmentation_in_orig"], + help="Output path for orientation volume LTA transform. 
This makes sure the midplane is the volume center, "
+             "the anterior and posterior commissures are on the coordinate line, and the posterior commissure is "
+             "at the origin, standardizing the head position.",
+        default=DEFAULT_OUTPUT_PATHS["orient_volume_lta"],
     )
     advanced.add_argument(
         "--qc_image",
         type=path_or_none,
-        help="Path for QC visualization image .",
+        help="Output path for QC visualization image.",
         default=DEFAULT_OUTPUT_PATHS["qc_image"],
     )
     advanced.add_argument(
         "--save_template_dir",
         type=path_or_none,
         help="Directory path where to save contours.txt and thickness_values.txt files. These files can be used to "
-             "visualize the CC shape and volume in 3D.",
+             "visualize the CC shape and volume with the cc_visualization.py script.",
         default=None,
     )
     advanced.add_argument(
         "--thickness_image",
         type=path_or_none,
-        help="Path for thickness image.",
+        help="Output path for thickness image.",
         default=DEFAULT_OUTPUT_PATHS["thickness_image"],
     )
     advanced.add_argument(
         "--surf",
         dest="cc_surf",
         type=path_or_none,
-        help="Path for surf file.",
+        help="Output path for surf file.",
         default=DEFAULT_OUTPUT_PATHS["cc_surf"],
     )
     advanced.add_argument(
         "--thickness_overlay",
         type=path_or_none,
-        help="Path for corpus callosum thickness overlay file.",
+        help="Output path for corpus callosum thickness overlay file.",
         default=DEFAULT_OUTPUT_PATHS["cc_thickness_overlay"],
     )
     advanced.add_argument(
         "--cc_interactive_html", "--cc_html",
         dest="cc_html",
         type=path_or_none,
-        help="Path to the corpus callosum interactive 3D visualization HTML file.",
+        help="Output path to the corpus callosum interactive 3D visualization HTML file.",
         default=DEFAULT_OUTPUT_PATHS["cc_html"],
     )
     advanced.add_argument(
         "--cc_surf_vtk",
         type=path_or_none,
-        help=f"Path for vtk file, showing the CC 3D mesh. Example: {DEFAULT_OUTPUT_PATHS['cc_surf_vtk']}.",
+        help=f"Output path for vtk file, showing the CC 3D mesh. Example: {DEFAULT_OUTPUT_PATHS['cc_surf_vtk']}.",
         default=None,
     )
     advanced.add_argument(
         "--softlabels_cc",
         type=path_or_none,
-        help=f"Path for corpus callosum softlabels, which contains the soft labels of each voxel. "
+        help=f"Output path for corpus callosum softlabels, which contains the soft labels of each voxel. "
              f"Example: {DEFAULT_OUTPUT_PATHS['softlabels_cc']}.",
         default=None,
     )
     advanced.add_argument(
         "--softlabels_fn",
         type=path_or_none,
-        help=f"Path for fornix softlabels, which contains the soft labels of each voxel. "
+        help=f"Output path for fornix softlabels, which contains the soft labels of each voxel. "
              f"Example: {DEFAULT_OUTPUT_PATHS['softlabels_fn']}.",
         default=None,
     )
     advanced.add_argument(
         "--softlabels_background",
         type=path_or_none,
-        help=f"Path for background softlabels, which contains the probability of each voxel. "
+        help=f"Output path for background softlabels, which contains the probability of each voxel. "
" f"Example: {DEFAULT_OUTPUT_PATHS['softlabels_background']}.", default=None, ) From a4c49c0124c2628ea9736e6697dbc9b1b4bcf209 Mon Sep 17 00:00:00 2001 From: ClePol Date: Tue, 9 Dec 2025 14:46:17 +0100 Subject: [PATCH 45/68] split cc_mesh class into cc_mesh and cc_contour --- CorpusCallosum/shape/contour.py | 497 ++++++++++ CorpusCallosum/shape/mesh.py | 1175 ++++++------------------ CorpusCallosum/shape/postprocessing.py | 40 +- CorpusCallosum/utils/visualization.py | 4 + 4 files changed, 786 insertions(+), 930 deletions(-) create mode 100644 CorpusCallosum/shape/contour.py diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py new file mode 100644 index 00000000..36aa3f80 --- /dev/null +++ b/CorpusCallosum/shape/contour.py @@ -0,0 +1,497 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Literal + +import lapy +import numpy as np +import scipy.interpolate +from scipy.ndimage import gaussian_filter1d + +import FastSurferCNN.utils.logging as logging +from CorpusCallosum.shape.endpoint_heuristic import smooth_contour +from FastSurferCNN.utils.common import suppress_stdout + +logger = logging.get_logger(__name__) + + +class CCContour: + """A class for representing and manipulating corpus callosum (CC) contours. + + This class provides functionality for manipulating and analyzing corpus callosum contours. + + Attributes + ---------- + contour : np.ndarray + Array of shape (N, 2) containing 2D contour points. + endpoint_idxs : tuple[int, int] + Tuple containing start and end indices for the contour. + """ + + + def __init__( + self, + contour: np.ndarray[tuple[Literal["N", 2]], np.dtype[float]], + thickness_values: np.ndarray[tuple[Literal["N"]], np.dtype[float]], + endpoint_idxs: tuple[int, int] | None = None, + resolution: float = 1.0 + ): + """Initialize a CCContour object. + + Parameters + ---------- + contour : np.ndarray + Array of shape (N, 2) containing 2D contour points. + thickness_values : np.ndarray + Array of thickness measurements for each contour point. + endpoint_idxs : tuple[int, int] + Tuple containing start and end indices for the contour. + """ + self.contour = contour + self.thickness_values = thickness_values + # write vertex indices where thickness values are not nan + self.original_thickness_vertices = np.where(~np.isnan(thickness_values))[0] + self.resolution = resolution + + if endpoint_idxs is None: + self.endpoint_idxs = (0, len(contour) // 2) + else: + self.endpoint_idxs = endpoint_idxs + + def smooth_contour(self, window_size: int = 5) -> None: + """Smooth a contour using a moving average filter. + + Parameters + ---------- + contour_idx : int + Index of the contour to smooth. + window_size : int, default=5 + Size of the smoothing window. + + Notes + ----- + Uses smooth_contour from cc_endpoint_heuristic module to: + 1. Extract x and y coordinates. + 2. Apply moving average smoothing. + 3. 
+        """
+        x, y = self.contour.T
+        x, y = smooth_contour(x, y, window_size)
+        self.contour = np.array([x, y]).T
+
+    def get_contour_edge_lengths(self) -> np.ndarray:
+        """Get the lengths of the edges of the contour.
+
+        Returns
+        -------
+        np.ndarray
+            Array of edge lengths for the contour.
+
+        Notes
+        -----
+        Edge lengths are calculated as Euclidean distances between consecutive points
+        in the contour.
+        """
+        edges = np.diff(self.contour, axis=0)
+        return np.sqrt(np.sum(edges**2, axis=1))
+
+    def _create_levelpaths(
+        self,
+        points: np.ndarray,
+        trias: np.ndarray,
+        num_points: int | None = None
+    ) -> tuple[list[np.ndarray], list[float]]:
+        """Create level paths for thickness measurements.
+
+        Parameters
+        ----------
+        points : np.ndarray
+            Array of shape (N, 2) containing mesh points.
+        trias : np.ndarray
+            Array of shape (M, 3) containing triangle indices.
+        num_points : int or None, optional
+            Number of points to sample along the midline, by default None.
+
+        Returns
+        -------
+        tuple[list[np.ndarray], list[float]]
+            - levelpaths : List of arrays containing level path coordinates.
+            - thickness_values : List of thickness values for each level path.
+
+        Notes
+        -----
+        The function:
+        1. Creates a triangular mesh from the points.
+        2. Finds boundary points and endpoints.
+        3. Solves a Poisson equation for the level sets.
+        4. Extracts level paths and interpolates thickness values.
+        """
+        with suppress_stdout():
+            cc_tria = lapy.TriaMesh(points, trias)
+            # extract boundary curve
+            bdr = np.array(cc_tria.boundary_loops()[0])
+
+        # find index of endpoints in bdr list
+        iidx1 = np.where(bdr == self.endpoint_idxs[0])[0][0]
+        iidx2 = np.where(bdr == self.endpoint_idxs[1])[0][0]
+
+        # create boundary condition (0 at endpoints, -1 on one side, 1 on the other):
+        if iidx1 > iidx2:
+            iidx1, iidx2 = iidx2, iidx1
+        dcond = np.ones(bdr.shape)
+        dcond[iidx1] = 0
+        dcond[iidx2] = 0
+        dcond[iidx1 + 1 : iidx2] = -1
+
+        # Extract path
+        with suppress_stdout():
+            fem = lapy.Solver(cc_tria)
+            vfunc = fem.poisson(0, (bdr, dcond))
+            if num_points is not None:
+                # TODO: do midline stuff
+                level = 0
+                midline_equidistant, midline_length = cc_tria.level_path(vfunc, level, n_points=num_points + 2)
+                midline_equidistant = midline_equidistant[:, :2]
+                eval_points = midline_equidistant
+            else:
+                eval_points = self.contour
+            gf = lapy.diffgeo.compute_rotated_f(cc_tria, vfunc)
+
+        # interpolate midline to get levels to evaluate
+        gf_interp = scipy.interpolate.griddata(cc_tria.v[:, 0:2], gf, eval_points, method="nearest")
+
+        # sort by value
+        sorting_idx_gf = np.argsort(gf_interp)
+        gf_interp = gf_interp[sorting_idx_gf]
+        sorted_thickness_values = self.thickness_values[sorting_idx_gf]
+
+        # get levels to evaluate
+        # level_length = tria.level_length(gf, gf_interp)
+
+        levelpaths = []
+        thickness_values = []
+
+        for i in range(0, len(eval_points)):
+            level = gf_interp[i]
+            # levelpath starts at index zero
+            if level == 0:
+                continue
+            lvlpath, lvlpath_length, tria_idx = cc_tria.level_path(gf, level, get_tria_idx=True)
+
+            levelpaths.append(lvlpath)
+            thickness_values.append(sorted_thickness_values[i])
+
+        return levelpaths, thickness_values
+
+    def fill_thickness_values(self) -> None:
+        """Interpolate missing thickness values using weighted averaging.
+
+        Notes
+        -----
+        The function:
+        1. Scans the contour for missing (NaN) thickness values.
+        2. For each missing value:
+           - Finds the two closest points with known thickness.
+           - Calculates distances along the contour.
+           - Computes a weighted average based on inverse distance.
+        3. Updates the thickness values in place.
+
+        The weights are calculated as inverse distances to ensure closer
+        points have more influence on the interpolated value.
+        """
+        thickness = self.thickness_values
+        edge_lengths = self.get_contour_edge_lengths()
+
+        # Find indices of points with known thickness
+        known_idx = np.where(~np.isnan(thickness))[0]
+
+        # For each point with unknown thickness
+        for j in range(len(thickness)):
+            if not np.isnan(thickness[j]):
+                continue
+
+            # Find two closest points with known thickness
+            distances = np.zeros(len(known_idx))
+            for k, idx in enumerate(known_idx):
+                # Calculate distance along contour by summing edge lengths
+                if idx > j:
+                    distances[k] = np.sum(edge_lengths[j:idx])
+                else:
+                    distances[k] = np.sum(edge_lengths[idx:j])
+
+            # Get indices of two closest points
+            closest_indices = known_idx[np.argsort(distances)[:2]]
+            closest_distances = np.sort(distances)[:2]
+
+            # Calculate weights based on inverse distance
+            weights = 1.0 / closest_distances
+            weights = weights / np.sum(weights)
+
+            # Calculate weighted average thickness
+            thickness[j] = np.sum(weights * thickness[closest_indices])
+
+        self.thickness_values = thickness
+
+    def smooth_thickness_values(self, iterations: int = 1) -> None:
+        """Smooth the thickness values using a Gaussian filter.
+
+        Parameters
+        ----------
+        iterations : int, optional
+            Number of smoothing iterations, by default 1.
+
+        Notes
+        -----
+        Applies Gaussian smoothing with sigma=5 to the thickness values
+        of this contour, once per iteration.
+        """
+        for _ in range(iterations):
+            self.thickness_values = gaussian_filter1d(self.thickness_values, sigma=5)
+
+    @staticmethod
+    def __make_parent_folder(filename: Path | str) -> None:
+        """Create the parent folder for a file if it doesn't exist.
+
+        Parameters
+        ----------
+        filename : Path, str
+            Path to the file whose parent folder should be created.
+
+        Notes
+        -----
+        Creates the parent directory with parents=False to avoid creating
+        multiple levels of directories unintentionally.
+        """
+        Path(filename).parent.mkdir(parents=False, exist_ok=True)
+
+    def save_thickness_measurement_points(self, filename: Path | str) -> None:
+        """Write the thickness measurement points to a CSV file.
+
+        Parameters
+        ----------
+        filename : Path, str
+            Path where to save the CSV file.
+
+        Notes
+        -----
+        The function saves measurement points in CSV format with:
+        - Header: vertex_idx.
+        - Each measured vertex index gets its own row.
+        """
+        self.__make_parent_folder(filename)
+        logger.info(f"Saving thickness measurement points to CSV file: {filename}")
+        with open(filename, "w") as f:
+            f.write("vertex_idx\n")
+            for vertex_idx in self.original_thickness_vertices:
+                f.write(f"{vertex_idx}\n")
+
+    @staticmethod
+    def _load_thickness_measurement_points(filename: str) -> list[np.ndarray | None]:
+        """Load thickness measurement points from a CSV file.
+
+        Parameters
+        ----------
+        filename : str
+            Path to the CSV file containing measurement points.
+
+        Returns
+        -------
+        list[np.ndarray | None]
+            List of arrays containing vertex indices for each slice where
+            thickness was measured. None for slices without measurements.
+
+        Notes
+        -----
+        The function:
+        1. Reads a CSV file with the format: slice_idx,vertex_idx.
+        2. Groups vertex indices by slice index.
+        3. Creates a list with length matching the maximum slice index.
+        4. Fills the list with vertex index arrays, or None for missing slices.
+        """
+        data = np.loadtxt(filename, delimiter=",", skiprows=1)
+        slice_indices = data[:, 0].astype(int)
+        vertex_indices = data[:, 1].astype(int)
+
+        # Group values by slice_idx
+        unique_slices = np.unique(slice_indices)
+
+        # split data into slices
+        original_thickness_vertices = [None] * (max(unique_slices) + 1)
+        for slice_idx in unique_slices:
+            mask = slice_indices == slice_idx
+            original_thickness_vertices[slice_idx] = vertex_indices[mask]
+        return original_thickness_vertices
+
+    def save_contour(self, output_path: Path | str) -> None:
+        """Save the contour to a CSV file.
+
+        Parameters
+        ----------
+        output_path : Path, str
+            Path to save the CSV file.
+
+        Notes
+        -----
+        The function saves the contour in CSV format with:
+        - Header: x,y.
+        - A marker line recording the anterior and posterior endpoint indices.
+        - Each point gets its own row with its x and y coordinates.
+        """
+        self.__make_parent_folder(output_path)
+        logger.info(f"Saving contours to CSV file: {output_path}")
+        with open(output_path, "w") as f:
+            f.write("x,y\n")
+            f.write(
+                f"New contour, anterior_endpoint_idx={self.endpoint_idxs[0]}, "
+                f"posterior_endpoint_idx={self.endpoint_idxs[1]}\n"
+            )
+            for point in self.contour:
+                f.write(f"{point[0]},{point[1]}\n")
+
+    def load_contour(self, input_path: str) -> None:
+        """Load the contour from a CSV file.
+
+        Parameters
+        ----------
+        input_path : str
+            Path to the CSV file containing the contour.
+
+        Raises
+        ------
+        ValueError
+            If the file format doesn't match the structure written by save_contour.
+
+        Notes
+        -----
+        The function:
+        1. Reads a CSV file with the format written by save_contour.
+        2. Parses the endpoint indices from the "New contour" marker line.
+        3. Reconstructs the contour points.
+        """
+        current_points = []
+
+        with open(input_path) as f:
+            # Skip header
+            next(f)
+
+            for line in f:
+                if line.startswith("New contour"):
+                    # recover the endpoint indices from the marker line
+                    _, anterior, posterior = line.strip().split(",")
+                    self.endpoint_idxs = (int(anterior.split("=")[1]), int(posterior.split("=")[1]))
+                    continue
+                x, y = line.strip().split(",")
+                current_points.append([float(x), float(y)])
+        self.contour = np.array(current_points)
+
+    def save_thickness_values(self, output_path: Path | str) -> None:
+        """Save thickness values to a CSV file.
+
+        Parameters
+        ----------
+        output_path : Path, str
+            Path to save the CSV file.
+
+        Notes
+        -----
+        The function saves thickness values in CSV format with:
+        - Header: thickness.
+        - Each thickness value gets its own row.
+        """
+        self.__make_parent_folder(output_path)
+        logger.info(f"Saving thickness data to CSV file: {output_path}")
+        with open(output_path, "w") as f:
+            f.write("thickness\n")
+            for value in self.thickness_values:
+                f.write(f"{value}\n")
+
+    def load_thickness_values(
+        self,
+        input_path: str,
+        original_thickness_vertices_path: str | None = None
+    ) -> None:
+        """Load thickness values from a CSV file.
+
+        Parameters
+        ----------
+        input_path : str
+            Path to the CSV file containing thickness values.
+        original_thickness_vertices_path : str or None, optional
+            Path to a file containing the indices of vertices where thickness
+            was measured, by default None.
+
+        Raises
+        ------
+        ValueError
+            If the number of thickness values doesn't match the measurement points.
+
+        Notes
+        -----
+        The function:
+        1. Reads the thickness values from the CSV file.
+        2. Optionally associates the values with specific measurement vertices.
+        3. Checks that the number of values matches the contour or the measurement points.
+        """
+        values = np.atleast_1d(np.loadtxt(input_path, delimiter=",", skiprows=1))
+
+        if original_thickness_vertices_path is None:
+            # check that the number of thickness values is equal to the number of points in the contour
+            assert len(values) == len(self.contour), (
+                "Number of thickness values does not match number of points in the contour, maybe you need to "
+                "provide the measurement points file"
+            )
+            # fill original_thickness_vertices with all indices
+            self.original_thickness_vertices = np.arange(len(self.contour))
+        else:
+            loaded_original_thickness_vertices = self._load_thickness_measurement_points(
+                original_thickness_vertices_path
+            )
+
+            if len(loaded_original_thickness_vertices) != len(values):
+                raise ValueError(
+                    "Number of measurement points does not match number of thickness values"
+                )
+
+        self.thickness_values = values
+        if len(values) != len(self.contour):
+            logger.error(
+                f"Tried to load {len(values)} values, but the contour has {len(self.contour)} points, "
+                "supply a correct template to visualize the thickness values"
+            )
\ No newline at end of file
diff --git a/CorpusCallosum/shape/mesh.py b/CorpusCallosum/shape/mesh.py
index 02fadfe7..7fc13c6c 100644
--- a/CorpusCallosum/shape/mesh.py
+++ b/CorpusCallosum/shape/mesh.py
@@ -27,9 +27,8 @@
 
 import FastSurferCNN.utils.logging as logging
 from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE
-from CorpusCallosum.shape.endpoint_heuristic import smooth_contour
+from CorpusCallosum.shape.contour import CCContour
 from CorpusCallosum.shape.thickness import make_mesh_from_contour
-from FastSurferCNN.utils.common import suppress_stdout
 
 try:
     from pyrr import Matrix44
@@ -42,6 +41,256 @@ class Matrix44(np.ndarray):
 logger = logging.get_logger(__name__)
 
 
+
+def _create_cap(
+    points: np.ndarray,
+    trias: np.ndarray,
+    contour: CCContour,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """Create a cap mesh for one end of the corpus callosum.
+
+    Parameters
+    ----------
+    points : np.ndarray
+        Array of shape (N, 2) containing mesh points.
+    trias : np.ndarray
+        Array of shape (M, 3) containing triangle indices.
+    contour : CCContour
+        CCContour object to create the cap for.
+
+    Returns
+    -------
+    tuple[np.ndarray, np.ndarray, np.ndarray]
+        - level_vertices : Array of vertices for the cap mesh.
+        - level_faces : Array of face indices for the cap mesh.
+        - level_colors : Array of thickness values for each vertex.
+
+    Notes
+    -----
+    The function:
+    1. Creates level paths using _create_levelpaths.
+    2. Resamples level paths to a fixed number of points.
+    3. Creates triangles between consecutive level paths.
+    4. Smooths thickness values for visualization.
+    """
+    levelpaths, thickness_values = contour._create_levelpaths(points, trias)
+
+    # Create mesh from level paths
+    level_vertices = []
+    level_faces = []
+    level_colors = []
+    vertex_counter = 0
+    sorted_thickness_values = np.array(thickness_values)
+
+    # smooth thickness values
+    for _ in range(3):
+        sorted_thickness_values = gaussian_filter1d(sorted_thickness_values, sigma=5)
+
+    NUM_LEVELPOINTS = 50
+
+    assert len(sorted_thickness_values) == len(levelpaths)
+
+    # TODO: handle gap between first/last levelpath and contour
+    for idx, levelpath1 in enumerate(levelpaths):
+        levelpath1 = lapy.TriaMesh._TriaMesh__iterative_resample_polygon(levelpath1, NUM_LEVELPOINTS)
+        level_vertices.append(levelpath1)
+        level_colors.append(np.full((len(levelpath1)), sorted_thickness_values[idx]))
+        if idx + 1 < len(levelpaths):
+            levelpath2 = lapy.TriaMesh._TriaMesh__iterative_resample_polygon(levelpaths[idx + 1], NUM_LEVELPOINTS)
+
+            # Create faces between the two paths by connecting vertices
+            faces_between = []
+            i, j = 0, 0
+
+            while i < len(levelpath1) - 1 and j < len(levelpath2) - 1:
+                faces_between.append([i, i + 1, len(levelpath1) + j])
+                faces_between.append([i + 1, len(levelpath1) + j + 1, len(levelpath1) + j])
+
+                i += 1
+                j += 1
+
+            while i < len(levelpath1) - 1:
+                faces_between.append([i, i + 1, len(levelpath1) + j])
+                i += 1
+
+            while j < len(levelpath2) - 1:
+                faces_between.append([i, len(levelpath1) + j + 1, len(levelpath1) + j])
+                j += 1
+
+            if faces_between:
+                faces_between = np.array(faces_between)
+                level_faces.append(faces_between + vertex_counter)
+
+        vertex_counter += len(levelpath1)
+
+    # Convert to numpy arrays
+    level_vertices = np.vstack(level_vertices)
+    level_faces = np.vstack(level_faces)
+    level_colors = np.concatenate(level_colors)
+
+    return level_vertices, level_faces, level_colors
+
+
+def make_triangles_between_contours(contour1: np.ndarray, contour2: np.ndarray) -> np.ndarray:
+    """Create a triangular mesh between two contours by connecting index-aligned points.
+
+    Parameters
+    ----------
+    contour1 : np.ndarray
+        First contour points of shape (N, 2).
+    contour2 : np.ndarray
+        Second contour points of shape (M, 2).
+
+    Returns
+    -------
+    np.ndarray
+        Array of triangle indices of shape (K, 3) where K is the number of triangles.
+
+    Notes
+    -----
+    The function:
+    1. Finds the closest point on contour2 to the first point of contour1.
+    2. Creates triangles by connecting corresponding points.
+    3. Handles contours with different numbers of points.
+    4. Creates two triangles to form a quad between each pair of points.
+    """
+    start_idx_c1 = 0
+    # get closest point on contour2 to contour1[0]
+    start_idx_c2 = np.argmin(np.linalg.norm(contour2 - contour1[0], axis=1))
+
+    triangles = []
+    n1 = len(contour1)
+    n2 = len(contour2)
+
+    for i in range(n1):
+        # Current and next indices for contour1
+        c1_curr = (start_idx_c1 + i) % n1
+        c1_next = (start_idx_c1 + i + 1) % n1
+
+        # Current and next indices for contour2, offset by n1 to account for vertex stacking
+        c2_curr = ((start_idx_c2 + i) % n2) + n1
+        c2_next = ((start_idx_c2 + i + 1) % n2) + n1
+
+        # Create two triangles to form a quad between the contours
+        triangles.append([c1_curr, c2_curr, c1_next])
+        triangles.append([c2_curr, c2_next, c1_next])
+
+    return np.array(triangles)
+
+
+def create_CC_mesh_from_contours(contours: list[CCContour],
+                                 lr_center: float = 0,
+                                 closed: bool = False,
+                                 smooth: int = 0) -> "CCMesh":
+    """Create a surface mesh by triangulating between consecutive contours.
+
+    Parameters
+    ----------
+    contours : list[CCContour]
+        List of CCContour objects to create the mesh from.
+    lr_center : float, optional
+        Center position in the left-right axis, by default 0.
+    closed : bool, optional
+        Whether to create a closed mesh by adding caps, by default False.
+    smooth : int, optional
+        Number of smoothing iterations to apply, by default 0.
+
+    Returns
+    -------
+    CCMesh
+        The triangulated corpus callosum surface mesh.
+
+    Raises
+    ------
+    ValueError
+        If the contours do not all share the same resolution.
+
+    Notes
+    -----
+    The function:
+    1. Checks that all contours share the same resolution.
+    2. Calculates z-coordinates for each slice.
+    3. Creates triangles between adjacent contours.
+    4. Optionally:
+       - Creates caps at both ends.
+       - Applies smoothing.
+       - Colors caps based on thickness values.
+    """
+    # Check that all contours have the same resolution
+    resolution = contours[0].resolution
+    for idx, contour in enumerate(contours[1:], start=1):
+        if not np.isclose(contour.resolution, resolution):
+            raise ValueError(
+                f"All contours must have the same resolution. "
+                f"Expected {resolution}, but contour at index {idx} has {contour.resolution}."
+            )
+
+    # Calculate z coordinates for each slice
+    z_coordinates = (np.arange(len(contours)) - len(contours) // 2) * contours[0].resolution + lr_center
+
+    # Build vertices list with z-coordinates
+    vertices = []
+    faces = []
+    vertex_start_indices = []  # Track starting index for each contour
+    current_index = 0
+
+    for i, contour in enumerate(contours):
+        vertex_start_indices.append(current_index)
+        vertices.append(np.hstack([contour.contour, np.full((len(contour.contour), 1), z_coordinates[i])]))
+
+        # Check if there's a next contour to connect to
+        if i + 1 < len(contours):
+            contour2 = contours[i + 1]
+            faces_between = make_triangles_between_contours(contour.contour, contour2.contour)
+            faces.append(faces_between + current_index)
+
+        current_index += len(contour.contour)
+
+    # Stack the per-contour vertex and face lists so they can be sliced below
+    vertices = np.vstack(vertices)
+    faces = np.vstack(faces)
+    vertex_values = np.concatenate([contour.thickness_values for contour in contours])
+
+    if smooth > 0:
+        tmp_mesh = CCMesh(vertices, faces, vertex_values=vertex_values)
+        tmp_mesh.smooth_(smooth)
+        vertices = tmp_mesh.v
+        faces = tmp_mesh.t
+        vertex_values = tmp_mesh.mesh_vertex_colors
+
+    if closed:
+        # Close the mesh by creating caps on both ends
+        # Left cap (first slice) - use counterclockwise orientation
+        left_side_points, left_side_trias = make_mesh_from_contour(vertices[: vertex_start_indices[1]][..., :2])
+        left_side_points = np.hstack([left_side_points, np.full((len(left_side_points), 1), z_coordinates[0])])
+
+        # Right cap (last slice) - reverse points for proper orientation
+        right_side_points, right_side_trias = make_mesh_from_contour(vertices[vertex_start_indices[-1] :][..., :2])
+        right_side_points = np.hstack([right_side_points, np.full((len(right_side_points), 1), z_coordinates[-1])])
+
+        color_sides = True
+        if color_sides:
+            left_side_points, left_side_trias, left_side_colors = _create_cap(
+                left_side_points, left_side_trias, contours[0]
+            )
+            right_side_points, right_side_trias, right_side_colors = _create_cap(
+                right_side_points, right_side_trias, contours[-1]
+            )
+
+            # reverse right side trias
+            right_side_trias = right_side_trias[:, ::-1]
+
+        left_side_trias = left_side_trias + current_index
+        current_index += len(left_side_points)
+
+        right_side_trias = right_side_trias + current_index
+        current_index += len(right_side_points)
+
+        vertices = [vertices, left_side_points, right_side_points]
+        faces = [faces, left_side_trias, right_side_trias]
+        vertex_values = [vertex_values, left_side_colors, right_side_colors]
+
+    return CCMesh(vertices, faces, vertex_values=vertex_values, resolution=resolution)
+
+
 class CCMesh(lapy.TriaMesh):
     """A class for representing and manipulating corpus callosum (CC) meshes.
 
@@ -75,83 +324,24 @@ class CCMesh(lapy.TriaMesh):
         List of vertex indices where thickness was originally measured.
     """
 
-    def __init__(self, num_slices: int):
+    def __init__(self,
+                 vertices: list | np.ndarray,
+                 faces: list | np.ndarray,
+                 vertex_values: list | np.ndarray | None = None,
+                 resolution: float = 1.0):
         """Initialize a CC_Mesh object.
 
         Parameters
         ----------
-        num_slices : int
-            Number of slices in the corpus callosum mesh
-        """
-        super().__init__(np.zeros((3, 3)), np.zeros((3, 3), dtype=int))
-        self.contours: list[np.ndarray | None] = [None] * num_slices
-        self.thickness_values: list[np.ndarray | None] = [None] * num_slices
-        self.start_end_idx: list[int | None] = [None] * num_slices
-        self.ac_coords: np.ndarray | None = None
-        self.pc_coords: np.ndarray | None = None
-        self.resolution: tuple[float, float, float] | None = None
-        # FIXME: v and t do not get properly initialized and all the data in the base class are basically unvalidated
-        #  this class needs to be reworked to either:
-        #  A) properly inherit from TriaMesh, calling super().__init__ with the correct values, or
-        #  B) converting it into a Factory class that then outputs a correct TriaMesh object.
-        #  Currently, there are no real behavior "guarantees" of objects, as the internal state of the object is
-        #  very chaotic and uncontrolled with almost no safeguards (and/or debugging).
-        self.v = None
-        self.t = None
-        self.original_thickness_vertices: list[np.ndarray | None] = [None] * num_slices
-
-    def add_contour(
-        self,
-        slice_idx: int,
-        contour: np.ndarray,
-        thickness_values: np.ndarray,
-        start_end_idx: tuple[int, int] | None = None,
-    ):
-        """Add a contour and its associated thickness values for a specific slice.
-
-        Parameters
-        ----------
-        slice_idx : int
-            Index of the slice where the contour should be added.
-        contour : np.ndarray
-            Array of shape (N, 2) containing 2D contour points.
-        thickness_values : np.ndarray
-            Array of thickness measurements for each contour point.
-        start_end_idx : tuple[int, int], optional
-            Tuple containing start and end indices for the contour.
-            If None, defaults to (0, len(contour)//2).
-        """
-        self.contours[slice_idx] = contour
-        self.thickness_values[slice_idx] = thickness_values
-        # write vertex indices where thickness values are not nan
-        self.original_thickness_vertices[slice_idx] = np.where(~np.isnan(thickness_values))[0]
-
-        if start_end_idx is None:
-            self.start_end_idx[slice_idx] = (0, len(contour) // 2)
-        else:
-            self.start_end_idx[slice_idx] = start_end_idx
-
-    def set_acpc_coords(self, ac_coords: np.ndarray, pc_coords: np.ndarray):
-        """Set the coordinates of the anterior and posterior commissure.
-
-        Parameters
-        ----------
-        ac_coords : np.ndarray
-            3D coordinates of the anterior commissure.
-        pc_coords : np.ndarray
-            3D coordinates of the posterior commissure.
-        """
-        self.ac_coords = ac_coords
-        self.pc_coords = pc_coords
-
-    def set_resolution(self, resolution: tuple[float, float, float]):
-        """Set the spatial resolution of the mesh.
-
-        Parameters
-        ----------
-        resolution : triplet of floats
-            LIA-oriented spatial resolution of the mesh.
+        vertices : list or numpy.ndarray
+            List of vertex coordinates or array of shape (N, 3).
+        faces : list or numpy.ndarray
+            List of face indices or array of shape (M, 3).
+        vertex_values : list or numpy.ndarray, optional
+            Vertex values for each vertex (CC thickness values).
+        resolution : float, default=1.0
+            Left-right spacing between the contour slices in mm.
         """
+        super().__init__(np.vstack(vertices), np.vstack(faces))
+        self.mesh_vertex_colors = vertex_values
         self.resolution = resolution
 
     def plot_mesh(
@@ -159,10 +349,8 @@
         output_path: Path | str | None = None,
         colormap: str = "red_to_yellow",
         thickness_overlay: bool = True,
-        show_contours: bool = False,
         show_grid: bool = False,
        color_range: tuple[float, float] | None = None,
-        show_mesh_edges: bool = False,
         legend: str = "",
         threshold: tuple[float, float] | None = None,
     ):
@@ -343,79 +531,6 @@
 
         fig.add_trace(go.Mesh3d(**mesh_args))
 
-        if show_contours:
-            # Add contour polylines for reference
-            num_slices = len(self.contours)
-
-            # Calculate z coordinates for each slice - use same calculation as in create_mesh
-            lr_center = self.v[len(self.v) // 2][2]
-            z_coordinates = (np.arange(num_slices) - (num_slices // 2)) * self.resolution[0] + lr_center
-
-            for i in range(num_slices):
-                if self.contours[i] is not None:
-                    # Use slice position for z coordinate
-                    z_coord = z_coordinates[i]
-                    contour = self.contours[i]
-
-                    # Create 3D points with fixed z coordinate
-                    v_i = np.hstack([contour, np.full((len(contour), 1), z_coord)])
-
-                    # Close the contour by adding the first point at the end
-                    v_i = np.vstack([v_i, v_i[0]])
-
-                    fig.add_trace(
-                        go.Scatter3d(
-                            x=v_i[:, 0],
-                            y=v_i[:, 1],
-                            z=v_i[:, 2],
-                            mode="lines",
-                            line=dict(color="white", width=2),
-                            opacity=0.5,
-                            hoverinfo="skip",
-                            showlegend=False,
-                        )
-                    )
-        if show_mesh_edges:  # show the mesh edges
-            edge_color = "darkgray"
-            vertices_in_first_contour = len(self.contours[0])
-
-            vertices_to_plot_first = np.concatenate([self.v[:vertices_in_first_contour], self.v[None, 0]])
-            # Add mesh edges for first 900 vertices as one continuous line
-            fig.add_trace(
-                go.Scatter3d(
-                    x=vertices_to_plot_first[:, 0],
-                    y=vertices_to_plot_first[:, 1],
-                    z=vertices_to_plot_first[:, 2],
-                    mode="lines",
-                    line=dict(color=edge_color, width=8),
-                    opacity=1,
-                    hoverinfo="skip",
-                    showlegend=False,
-                )
-            )
-
-            vertices_in_last_contour = len(self.contours[-1])
-
-            vertices_before_last_contour = np.sum([len(c) for c in self.contours[:-1]])
-            vertices_to_plot_last = np.concatenate(
-                [
-                    self.v[vertices_before_last_contour : vertices_before_last_contour + vertices_in_last_contour],
-                    self.v[None, vertices_before_last_contour],
-                ]
-            )
-            fig.add_trace(
-                go.Scatter3d(
-                    x=vertices_to_plot_last[:, 0],
-                    y=vertices_to_plot_last[:, 1],
-                    z=vertices_to_plot_last[:, 2],
-                    mode="lines",
-                    line=dict(color=edge_color, width=8),
-                    opacity=1,
-                    hoverinfo="skip",
-                    showlegend=False,
-                )
-            )
-
         # Calculate axis ranges to maintain equal aspect ratio
         ranges = []
         for i in range(3):
@@ -463,428 +578,8 @@
             plotly_write_html(fig, temp_path, include_plotlyjs="cdn")
             webbrowser.open(f"file://{temp_path}")
 
-    def get_contour_edge_lengths(self, contour_idx: int) -> np.ndarray:
-        """Get the lengths of the edges of a contour.
-
-        Parameters
-        ----------
-        contour_idx : int
-            Index of the contour to get the edge lengths for.
-
-        Returns
-        -------
-        np.ndarray
-            Array of edge lengths for the contour.
-
-        Notes
-        -----
-        Edge lengths are calculated as Euclidean distances between consecutive points
-        in the contour.
- """ - edges = np.diff(self.contours[contour_idx], axis=0) - return np.sqrt(np.sum(edges**2, axis=1)) - - @staticmethod - def make_triangles_between_contours(contour1: np.ndarray, contour2: np.ndarray) -> np.ndarray: - """Create a triangular mesh between two contours using a robust method. - - Parameters - ---------- - contour1 : np.ndarray - First contour points of shape (N, 2). - contour2 : np.ndarray - Second contour points of shape (M, 2). - - Returns - ------- - np.ndarray - Array of triangle indices of shape (K, 3) where K is the number of triangles. - - Notes - ----- - The function: - 1. Finds closest point on contour2 to first point of contour1 - 2. Creates triangles by connecting corresponding points - 3. Handles contours with different numbers of points - 4. Creates two triangles to form a quad between each pair of points - """ - start_idx_c1 = 0 - # get closest point on contour2 to contour1[0] - start_idx_c2 = np.argmin(np.linalg.norm(contour2 - contour1[0], axis=1)) - - triangles = [] - n1 = len(contour1) - n2 = len(contour2) - - for i in range(n1): - # Current and next indices for contour1 - c1_curr = (start_idx_c1 + i) % n1 - c1_next = (start_idx_c1 + i + 1) % n1 - - # Current and next indices for contour2, offset by n1 to account for vertex stacking - c2_curr = ((start_idx_c2 + i) % n2) + n1 - c2_next = ((start_idx_c2 + i + 1) % n2) + n1 - - # Create two triangles to form a quad between the contours - triangles.append([c1_curr, c2_curr, c1_next]) - triangles.append([c2_curr, c2_next, c1_next]) - - return np.array(triangles) - - def _create_levelpaths( - self, - contour_idx: int, - points: np.ndarray, - trias: np.ndarray, - num_points: int | None = None - ) -> tuple[list[np.ndarray], list[float]]: - """Create level paths for thickness measurements. - - Parameters - ---------- - contour_idx : int - Index of the contour to process - points : np.ndarray - Array of shape (N, 2) containing mesh points - trias : np.ndarray - Array of shape (M, 3) containing triangle indices - num_points : int or None, optional - Number of points to sample along the midline, by default None - - Returns - ------- - tuple[list[np.ndarray], list[float]] - - levelpaths : List of arrays containing level path coordinates - - thickness_values : List of thickness values for each level path - - Notes - ----- - The function: - 1. Creates a triangular mesh from the points - 2. Finds boundary points and endpoints - 3. Solves Poisson equation for level sets - 4. 
Extracts level paths and interpolates thickness values - """ - - with suppress_stdout(): - cc_tria = lapy.TriaMesh(points, trias) - # extract boundary curve - bdr = np.array(cc_tria.boundary_loops()[0]) - - # find index of endpoints in bdr list - iidx1 = np.where(bdr == self.start_end_idx[contour_idx][0])[0][0] - iidx2 = np.where(bdr == self.start_end_idx[contour_idx][1])[0][0] - - # create boundary condition (0 at endpoints, -1 on one side, 1 on the other): - if iidx1 > iidx2: - tmp = iidx2 - iidx2 = iidx1 - iidx1 = tmp - dcond = np.ones(bdr.shape) - dcond[iidx1] = 0 - dcond[iidx2] = 0 - dcond[iidx1 + 1 : iidx2] = -1 - - # Extract path - with suppress_stdout(): - fem = lapy.Solver(cc_tria) - vfunc = fem.poisson(0, (bdr, dcond)) - if num_points is not None: - # TODO: do midline stuff - level = 0 - midline_equidistant, midline_length = cc_tria.level_path(vfunc, level, n_points=num_points + 2) - midline_equidistant = midline_equidistant[:, :2] - eval_points = midline_equidistant - else: - eval_points = self.contours[contour_idx] - gf = lapy.diffgeo.compute_rotated_f(cc_tria, vfunc) - - # interpolate midline to get levels to evaluate - gf_interp = scipy.interpolate.griddata(cc_tria.v[:, 0:2], gf, eval_points, method="nearest") - - # sort by value - sorting_idx_gf = np.argsort(gf_interp) - gf_interp = gf_interp[sorting_idx_gf] - sorted_thickness_values = self.thickness_values[contour_idx][sorting_idx_gf] - - # get levels to evaluate - # level_length = tria.level_length(gf, gf_interp) - - levelpaths = [] - thickness_values = [] - - for i in range(0, len(eval_points)): - level = gf_interp[i] - # levelpath starts at index zero - if level == 0: - continue - lvlpath, lvlpath_length, tria_idx = cc_tria.level_path(gf, level, get_tria_idx=True) - - levelpaths.append(lvlpath) - thickness_values.append(sorted_thickness_values[i]) - - return levelpaths, thickness_values - - def _create_cap( - self, - points: np.ndarray, - trias: np.ndarray, - contour_idx: int - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """Create a cap mesh for one end of the corpus callosum. - - Parameters - ---------- - points : np.ndarray - Array of shape (N, 2) containing mesh points - trias : np.ndarray - Array of shape (M, 3) containing triangle indices - contour_idx : int - Index of the contour to create cap for - - Returns - ------- - tuple[np.ndarray, np.ndarray, np.ndarray] - - level_vertices : Array of vertices for the cap mesh - - level_faces : Array of face indices for the cap mesh - - level_colors : Array of thickness values for each vertex - - Notes - ----- - The function: - 1. Creates level paths using _create_levelpaths - 2. Resamples level paths to fixed number of points - 3. Creates triangles between consecutive level paths - 4. 
Smooths thickness values for visualization - """ - levelpaths, thickness_values = self._create_levelpaths(contour_idx, points, trias) - - # Create mesh from level paths - level_vertices = [] - level_faces = [] - level_colors = [] - vertex_counter = 0 - sorted_thickness_values = np.array(thickness_values) - - # smooth thickness values - from scipy.ndimage import gaussian_filter1d - - for _ in range(3): - sorted_thickness_values = gaussian_filter1d(sorted_thickness_values, sigma=5) - - NUM_LEVELPOINTS = 50 - - assert len(sorted_thickness_values) == len(levelpaths) - - # TODO: handle gap between first/last levelpath and contour - for idx, levelpath1 in enumerate(levelpaths): - levelpath1 = lapy.TriaMesh._TriaMesh__iterative_resample_polygon(levelpath1, NUM_LEVELPOINTS) - level_vertices.append(levelpath1) - level_colors.append(np.full((len(levelpath1)), sorted_thickness_values[idx])) - if idx + 1 < len(levelpaths): - levelpath2 = lapy.TriaMesh._TriaMesh__iterative_resample_polygon(levelpaths[idx + 1], NUM_LEVELPOINTS) - - # Create faces between the two paths by connecting vertices - faces_between = [] - i, j = 0, 0 - while i < len(levelpath1) - 1 and j < len(levelpath2) - 1: - faces_between.append([i, i + 1, len(levelpath1) + j]) - faces_between.append([i + 1, len(levelpath1) + j + 1, len(levelpath1) + j]) - i += 1 - j += 1 - - while i < len(levelpath1) - 1: - faces_between.append([i, i + 1, len(levelpath1) + j]) - i += 1 - - while j < len(levelpath2) - 1: - faces_between.append([i, len(levelpath1) + j + 1, len(levelpath1) + j]) - j += 1 - - if faces_between: - faces_between = np.array(faces_between) - level_faces.append(faces_between + vertex_counter) - - vertex_counter += len(levelpath1) - - # Convert to numpy arrays - level_vertices = np.vstack(level_vertices) - level_faces = np.vstack(level_faces) - level_colors = np.concatenate(level_colors) - - return level_vertices, level_faces, level_colors - - def create_mesh(self, lr_center: float = 0, closed: bool = False, smooth: int = 0) -> None: - """Create a surface mesh by triangulating between consecutive contours. - - Parameters - ---------- - lr_center : float, optional - Center position in the left-right axis, by default 0. - closed : bool, optional - Whether to create a closed mesh by adding caps, by default False. - smooth : int, optional - Number of smoothing iterations to apply, by default 0. - - Raises - ------ - Warning - If no valid contours are found. - - Notes - ----- - The function: - 1. Filters out None contours. - 2. Calculates z-coordinates for each slice. - 3. Creates triangles between adjacent contours. - 4. Optionally: - - Creates caps at both ends. - - Applies smoothing. - - Colors caps based on thickness values. 
- - """ - # Filter out None contours and get their indices - valid_contours = [(i, c) for i, c in enumerate(self.contours) if c is not None] - if not valid_contours: - logger.warning("Warning: No valid contours found") - self.v = np.array([]) - self.t = np.array([]) - return - - # Calculate z coordinates for each slice - z_coordinates = (np.arange(len(valid_contours)) - len(valid_contours) // 2) * self.resolution[0] + lr_center - - # Build vertices list with z-coordinates - vertices = [] - faces = [] - vertex_start_indices = [] # Track starting index for each contour - current_index = 0 - - for i, (_, contour) in enumerate(valid_contours): - vertex_start_indices.append(current_index) - vertices.append(np.hstack([contour, np.full((len(contour), 1), z_coordinates[i])])) - - # Check if there's a next valid contour to connect to - if i + 1 < len(valid_contours): - next_idx, contour2 = valid_contours[i + 1] - faces_between = self.make_triangles_between_contours(contour, contour2) - faces.append(faces_between + current_index) - - current_index += len(contour) - - self.set_mesh(vertices, faces, self.thickness_values) - - if smooth > 0: - self.smooth_(smooth) - - if closed: - # Close the mesh by creating caps on both ends - # Left cap (first slice) - use counterclockwise orientation - left_side_points, left_side_trias = make_mesh_from_contour(self.v[: vertex_start_indices[1]][..., :2]) - left_side_points = np.hstack([left_side_points, np.full((len(left_side_points), 1), z_coordinates[0])]) - - # Right cap (last slice) - reverse points for proper orientation - right_side_points, right_side_trias = make_mesh_from_contour(self.v[vertex_start_indices[-1] :][..., :2]) - right_side_points = np.hstack([right_side_points, np.full((len(right_side_points), 1), z_coordinates[-1])]) - - color_sides = True - if color_sides: - left_side_points, left_side_trias, left_side_colors = self._create_cap( - left_side_points, left_side_trias, 0 - ) - right_side_points, right_side_trias, right_side_colors = self._create_cap( - right_side_points, right_side_trias, len(self.contours) - 1 - ) - - # reverse right side trias - right_side_trias = right_side_trias[:, ::-1] - - left_side_trias = left_side_trias + current_index - current_index += len(left_side_points) - - right_side_trias = right_side_trias + current_index - current_index += len(right_side_points) - - self.set_mesh( - [self.v, left_side_points, right_side_points], - [self.t, left_side_trias, right_side_trias], - [self.mesh_vertex_colors, left_side_colors, right_side_colors], - ) - - def fill_thickness_values(self) -> None: - """Interpolate missing thickness values using weighted averaging. - - Notes - ----- - The function: - 1. Processes each contour with missing thickness values. - 2. For each missing value: - - Finds two closest points with known thickness. - - Calculates distances along contour. - - Computes weighted average based on inverse distance. - 3. Updates thickness values in place. - - The weights are calculated as inverse distances to ensure closer - points have more influence on the interpolated value. 
- - """ - - # For each contour with missing thickness values - for i in range(len(self.contours)): - if self.contours[i] is None or self.thickness_values[i] is None: - continue - - thickness = self.thickness_values[i] - edge_lengths = self.get_contour_edge_lengths(i) - - # Find indices of points with known thickness - known_idx = np.where(~np.isnan(thickness))[0] - - # For each point with unknown thickness - for j in range(len(thickness)): - if not np.isnan(thickness[j]): - continue - - # Find two closest points with known thickness - distances = np.zeros(len(known_idx)) - for k, idx in enumerate(known_idx): - # Calculate distance along contour by summing edge lengths - if idx > j: - distances[k] = np.sum(edge_lengths[j:idx]) - else: - distances[k] = np.sum(edge_lengths[idx:j]) - - # Get indices of two closest points - closest_indices = known_idx[np.argsort(distances)[:2]] - closest_distances = np.sort(distances)[:2] - - # Calculate weights based on inverse distance - weights = 1.0 / closest_distances - weights = weights / np.sum(weights) - - # Calculate weighted average thickness - thickness[j] = np.sum(weights * thickness[closest_indices]) - - self.thickness_values[i] = thickness - - def smooth_thickness_values(self, iterations: int = 1) -> None: - """Smooth the thickness values using a Gaussian filter. - - Parameters - ---------- - iterations : int, optional - Number of smoothing iterations, by default 1. - - Notes - ----- - Applies Gaussian smoothing with sigma=5 to thickness values - for each slice that has measurements. - """ - for i in range(len(self.thickness_values)): - if self.thickness_values[i] is not None: - self.thickness_values[i] = gaussian_filter1d(self.thickness_values[i], sigma=5) def plot_contour(self, slice_idx: int, output_path: str) -> None: """Plot a single contour with thickness values. @@ -945,26 +640,7 @@ def plot_contour(self, slice_idx: int, output_path: str) -> None: plt.tight_layout() plt.savefig(output_path, dpi=300) - def smooth_contour(self, contour_idx: int, window_size: int = 5) -> None: - """Smooth a contour using a moving average filter. - - Parameters - ---------- - contour_idx : int - Index of the contour to smooth. - window_size : int, default=5 - Size of the smoothing window. - Notes - ----- - Uses smooth_contour from cc_endpoint_heuristic module to: - 1. Extract x and y coordinates. - 2. Apply moving average smoothing. - 3. Update contour with smoothed coordinates. - """ - x, y = self.contours[contour_idx].T - x, y = smooth_contour(x, y, window_size) - self.contours[contour_idx] = np.array([x, y]).T def plot_cc_contour_with_levelsets( self, @@ -1161,46 +837,7 @@ def plot_cc_contour_with_levelsets( plt.show() return fig - def set_mesh(self, - vertices: list | np.ndarray, - faces: list | np.ndarray, - thickness_values: list | np.ndarray | None = None) -> None: - """Set the mesh vertices, faces, and optional thickness values. - - Parameters - ---------- - vertices : list or numpy.ndarray - List of vertex coordinates or array of shape (N, 3). - faces : list or numpy.ndarray - List of face indices or array of shape (M, 3). - thickness_values : list or numpy.ndarray, optional - Thickness values for each vertex. - - Returns - ------- - None - The function does not return anything. 
- """ - # Handle case when there are no faces (single contour) - if not faces: - # For single contour, just store vertices without creating a mesh - vertices_array = np.vstack(vertices) if vertices else np.array([]).reshape(0, 3) - self.v = vertices_array - self.t = np.array([]).reshape(0, 3) - # Initialize fsinfo attribute that lapy expects - self.fsinfo = None - # Skip parent initialization since we have no faces - else: - #FIXME: based on this call and CCMesh.__init__, this whole class probably needs a rework. - super().__init__(np.vstack(vertices), np.vstack(faces)) - - if thickness_values is not None: - # Filter out empty thickness arrays and concatenate - valid_thickness = [tv for tv in thickness_values if tv is not None and len(tv) > 0] - if valid_thickness: - self.mesh_vertex_colors = np.concatenate(valid_thickness) - else: - self.mesh_vertex_colors = np.array([]) + @staticmethod def __create_cc_viewmat() -> "Matrix44": @@ -1352,223 +989,6 @@ def smooth_(self, iterations: int = 1) -> None: super().smooth_(iterations) self.v[:, 2] = z_values - def save_contours(self, output_path: Path | str) -> None: - """Save the contours to a CSV file. - - Parameters - ---------- - output_path : Path, str - Path where to save the CSV file. - - Notes - ----- - The function saves contours in CSV format with: - - Header: slice_idx,x,y. - - Special lines indicating new contours with endpoint indices. - - Each point gets its own row with slice index and coordinates. - """ - logger.info(f"Saving contours to CSV file: {output_path}") - with open(output_path, "w") as f: - # Write header - f.write("slice_idx,x,y\n") - # Write data - for slice_idx, contour in enumerate(self.contours): - if contour is not None: # Skip empty slices - f.write( - f"New contour, anterior_endpoint_idx={self.start_end_idx[slice_idx][0]}, " - f"posterior_endpoint_idx={self.start_end_idx[slice_idx][1]}\n" - ) - for point in contour: - f.write(f"{slice_idx},{point[0]},{point[1]}\n") - - def load_contours(self, input_path: str) -> None: - """Load contours from a CSV file. - - Parameters - ---------- - input_path : str - Path to the CSV file containing the contours. - - Raises - ------ - ValueError - If the file format doesn't match expected structure. - - Notes - ----- - The function: - 1. Reads CSV file with format matching save_contours output. - 2. Processes special lines for endpoint indices. - 3. Reconstructs contours and endpoint indices for each slice. - 4. Converts lists to fixed-size arrays with None padding. 
- """ - current_points = [] - self.contours = [] - self.start_end_idx = [] - - with open(input_path) as f: - # Skip header - next(f) - - for line in f: - if line.startswith("New contour"): - # If we have points from previous contour, save them - if current_points: - self.contours.append(np.array(current_points)) - current_points = [] - - # Extract anterior and posterior endpoint indices - # Format: "New contour, anterior_endpoint_idx=X,posterior_endpoint_idx=Y" - parts = line.strip().split(",") - anterior_idx = int(parts[1].split("=")[1]) - posterior_idx = int(parts[2].split("=")[1]) - self.start_end_idx.append((anterior_idx, posterior_idx)) - else: - # Parse point data - slice_idx, x, y = line.strip().split(",") - current_points.append([float(x), float(y)]) - - # Don't forget to add the last contour - if current_points: - self.contours.append(np.array(current_points)) - - # Convert lists to fixed-size arrays - max_slices = max(len(self.contours), len(self.start_end_idx)) - self.contours = self.contours + [None] * (max_slices - len(self.contours)) - self.start_end_idx = self.start_end_idx + [None] * (max_slices - len(self.start_end_idx)) - - def save_thickness_values(self, output_path: Path | str) -> None: - """Save thickness values to a CSV file. - - Parameters - ---------- - output_path : Path, str - Path where to save the CSV file. - - Notes - ----- - The function saves thickness values in CSV format with: - - Header: slice_idx,thickness. - - Each thickness value gets its own row with slice index. - - Skips slices with no thickness values. - """ - logger.info(f"Saving thickness data to CSV file: {output_path}") - with open(output_path, "w") as f: - # Write header - f.write("slice_idx,thickness\n") - # Write data - for slice_idx, thickness in enumerate(self.thickness_values): - if thickness is not None: # Skip empty slices - for value in thickness: - f.write(f"{slice_idx},{value}\n") - - def load_thickness_values( - self, - input_path: str, - original_thickness_vertices_path: str | None = None - ) -> None: - """Load thickness values from a CSV file. - - Parameters - ---------- - input_path : str - Path to the CSV file containing thickness values. - original_thickness_vertices_path : str or None, optional - Path to a file containing the indices of vertices where thickness - was measured, by default None. - - Raises - ------ - ValueError - If number of thickness values doesn't match measurement points - or if number of slices is inconsistent. - - Notes - ----- - The function: - 1. Reads thickness values from CSV file. - 2. Groups values by slice index. - 3. Optionally associates values with specific vertices. - 4. Handles both full contour and profile measurements. 
- - - """ - data = np.loadtxt(input_path, delimiter=",", skiprows=1) - slice_indices = data[:, 0].astype(int) - values = data[:, 1] - - # Group values by slice_idx - unique_slices = np.unique(slice_indices) - - # split data into slices - loaded_thickness_values = [None] * (max(unique_slices) + 1) - for slice_idx in unique_slices: - mask = slice_indices == slice_idx - loaded_thickness_values[slice_idx] = values[mask] - - if original_thickness_vertices_path is None: - # check that the number of thickness values for each slice is equal to the number of points in the contour - for slice_idx, thickness in enumerate(loaded_thickness_values): - if thickness is not None: - assert len(thickness) == len(self.contours[slice_idx]), ( - "Number of thickness values does not match number of points in the contour, maybe you need to " - "provide the measurement points file" - ) - # fill original_thickness_vertices with all indices - self.original_thickness_vertices = [ - np.arange(len(self.contours[slice_idx])) for slice_idx in range(len(self.contours)) - ] - else: - loaded_original_thickness_vertices = self._load_thickness_measurement_points( - original_thickness_vertices_path - ) - - if len(loaded_original_thickness_vertices) != len(loaded_thickness_values): - raise ValueError( - "Number of slices in measurement points does not match number of " - "slices in provided thickness values" - ) - - # check that original_thickness_vertices is equal to number of measurement points for each slice - for slice_idx, vertex_indices in enumerate(loaded_original_thickness_vertices): - if len(vertex_indices) // 2 == len(loaded_thickness_values[slice_idx]) or len( - vertex_indices - ) // 2 == np.sum(~np.isnan(loaded_thickness_values[slice_idx])): - is_thickness_profile = True - elif len(vertex_indices) == len(loaded_thickness_values[slice_idx]) or len(vertex_indices) == np.sum( - ~np.isnan(loaded_thickness_values[slice_idx]) - ): - is_thickness_profile = False - else: - raise ValueError("Number of measurement points does not match number of thickness values") - - # create nan thickness value array for each slice - new_thickness_values = [ - np.full(len(self.contours[slice_idx]), np.nan) for slice_idx in range(len(self.contours)) - ] - for slice_idx, vertex_indices in enumerate(loaded_original_thickness_vertices): - if is_thickness_profile: - new_thickness_values[slice_idx][vertex_indices] = np.concatenate( - [loaded_thickness_values[slice_idx], loaded_thickness_values[slice_idx][::-1]] - ) - else: - try: - new_thickness_values[slice_idx][vertex_indices] = loaded_thickness_values[slice_idx][ - ~np.isnan(loaded_thickness_values[slice_idx])] - except IndexError as err: - logger.error( - f"Tried to load " - f"{loaded_thickness_values[slice_idx][~np.isnan(loaded_thickness_values[slice_idx])]} " - f"values, but template has {new_thickness_values[slice_idx][vertex_indices]} values, " - "supply a correct template to visualize the thickness values" - ) - raise ValueError( - f"Tried to load " - f"{loaded_thickness_values[slice_idx][~np.isnan(loaded_thickness_values[slice_idx])]} " - f"values, but template has {new_thickness_values[slice_idx][vertex_indices]} values, " - "supply a correct template to visualize the thickness values" - ) from err - self.thickness_values = new_thickness_values @staticmethod def __make_parent_folder(filename: Path | str) -> None: @@ -1621,7 +1041,7 @@ def to_fs_coordinates( # all other operations are independent of order of operations (distributive) # v_vox /= vox_size[0] # center LR - v_vox[:, 0] += 
FSAVERAGE_MIDDLE / self.resolution[0] + v_vox[:, 0] += FSAVERAGE_MIDDLE / self.resolution # flip SI v_vox[:, 1] = -v_vox[:, 1] @@ -1653,7 +1073,7 @@ def write_fssurf(self, filename: Path | str) -> None: self.__make_parent_folder(filename) return super().write_fssurf(filename) - def write_overlay(self, filename: Path | str) -> None: + def write_morph_data(self, filename: Path | str) -> None: """Write the thickness values as a FreeSurfer overlay file. Parameters @@ -1667,64 +1087,3 @@ def write_overlay(self, filename: Path | str) -> None: """ self.__make_parent_folder(filename) return nib.freesurfer.write_morph_data(filename, self.mesh_vertex_colors) - - def save_thickness_measurement_points(self, filename: Path | str) -> None: - """Write the thickness measurement points to a CSV file. - - Parameters - ---------- - filename : Path, str - Path where to save the CSV file. - - Notes - ----- - The function saves measurement points in CSV format with: - - Header: slice_idx,vertex_idx. - - Each measurement point gets its own row. - - Skips slices with no measurement points. - """ - self.__make_parent_folder(filename) - logger.info(f"Saving thickness measurement points to CSV file: {filename}") - with open(filename, "w") as f: - f.write("slice_idx,vertex_idx\n") - for slice_idx, vertex_indices in enumerate(self.original_thickness_vertices): - if vertex_indices is not None: - for vertex_idx in vertex_indices: - f.write(f"{slice_idx},{vertex_idx}\n") - - @staticmethod - def _load_thickness_measurement_points(filename: str) -> list[np.ndarray | None]: - """Load thickness measurement points from a CSV file. - - Parameters - ---------- - filename : str - Path to the CSV file containing measurement points. - - Returns - ------- - list[np.ndarray | None] - List of arrays containing vertex indices for each slice where - thickness was measured. None for slices without measurements. - - Notes - ----- - The function: - 1. Reads CSV file with format: slice_idx,vertex_idx - 2. Groups vertex indices by slice index - 3. Creates a list with length matching max slice index - 4. 
Fills list with vertex indices arrays or None for missing slices - """ - data = np.loadtxt(filename, delimiter=",", skiprows=1) - slice_indices = data[:, 0].astype(int) - vertex_indices = data[:, 1].astype(int) - - # Group values by slice_idx - unique_slices = np.unique(slice_indices) - - # split data into slices - original_thickness_vertices = [None] * (max(unique_slices) + 1) - for slice_idx in unique_slices: - mask = slice_indices == slice_idx - original_thickness_vertices[slice_idx] = vertex_indices[mask] - return original_thickness_vertices diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index cd167d96..0c6fa035 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -19,8 +19,9 @@ import FastSurferCNN.utils.logging as logging from CorpusCallosum.data.constants import CC_LABEL, FSAVERAGE_MIDDLE, SUBSEGMENT_LABELS +from CorpusCallosum.shape.contour import CCContour from CorpusCallosum.shape.endpoint_heuristic import get_endpoints -from CorpusCallosum.shape.mesh import CCMesh +from CorpusCallosum.shape.mesh import CCMesh, create_CC_mesh_from_contours from CorpusCallosum.shape.metrics import calculate_cc_index from CorpusCallosum.shape.subsegment_contour import ( get_primary_eigenvector, @@ -213,16 +214,14 @@ def recon_cc_surf_measures_multi( per_slice_vox2ras = fsavg_vox2ras @ np.stack(list(map(_gen_fsavg2slice_vox2vox, slices_to_recon)), axis=0) per_slice_recon = process_executor().map(_each_slice, slices_to_recon, per_slice_vox2ras, chunksize=1) - cc_mesh = CCMesh(num_slices=num_slices) - cc_mesh.set_acpc_coords(ac_coords, pc_coords) - cc_mesh.set_resolution(vox_size) + cc_contours = [] for i, (slice_idx, _results) in enumerate(zip(slices_to_recon, per_slice_recon, strict=True)): progress = f" ({i+1} of {num_slices})" if num_slices > 1 else "" logger.info(f"Calculating CC measurements for slice {slice_idx+1}{progress}") cc_measures, contour_in_as_space_and_thickness, endpoint_idxs = _results contour_in_as_space, thickness_values = np.split(contour_in_as_space_and_thickness, (2,), axis=1) - cc_mesh.add_contour(start_slice - slice_idx, contour_in_as_space, thickness_values[:, 0], endpoint_idxs) + cc_contours.append(CCContour(contour_in_as_space, thickness_values[:, 0], endpoint_idxs, resolution=vox_size[0])) if cc_measures is None: # this should not happen, but just in case logger.warning(f"Slice index {slice_idx+1}{progress} returned result `None`") @@ -265,27 +264,24 @@ def recon_cc_surf_measures_multi( template_dir.mkdir(parents=True, exist_ok=True) logger.info("Saving template files (contours.txt, thickness_values.txt, " f"thickness_measurement_points.txt) to {template_dir}") - io_futures.extend([ - thread_executor().submit(cc_mesh.save_contours, template_dir / "contours.txt"), - thread_executor().submit(cc_mesh.save_thickness_values, template_dir / "thickness_values.txt"), - thread_executor().submit( - cc_mesh.save_thickness_measurement_points, - template_dir / "thickness_measurement_points.txt", - ), - ]) + for j in range(len(cc_contours)): + io_futures.extend([ + thread_executor().submit(cc_contours[j].save_contour, template_dir / f"contour_{j}.txt"), + thread_executor().submit(cc_contours[j].save_thickness_values, + template_dir / f"thickness_values_{j}.txt"), + thread_executor().submit(cc_contours[j].save_thickness_measurement_points, + template_dir / f"thickness_measurement_points_{j}.txt"), + ]) mesh_outputs = ("html", "mesh", "thickness_overlay", "surf", "thickness_image") - if 
len(cc_mesh.contours) > 1 and any(subject_dir.has_attribute(f"cc_{n}") for n in mesh_outputs):
-        cc_mesh.fill_thickness_values()
-        cc_mesh.create_mesh()
-        cc_mesh.smooth_(1)
+    if len(cc_contours) > 1 and any(subject_dir.has_attribute(f"cc_{n}") for n in mesh_outputs):
+        for j in range(len(cc_contours)):
+            cc_contours[j].fill_thickness_values()
+        cc_mesh = create_CC_mesh_from_contours(cc_contours, smooth=1)
 
         if subject_dir.has_attribute("cc_html"):
             logger.info(f"Saving CC 3D visualization to {subject_dir.filename_by_attribute('cc_html')}")
             io_futures.append(thread_executor().submit(
-                cc_mesh.plot_mesh,
-                output_path=subject_dir.filename_by_attribute("cc_html"),
-                show_mesh_edges=True,
-            ))
+                cc_mesh.plot_mesh, output_path=subject_dir.filename_by_attribute("cc_html")))
 
         if subject_dir.has_attribute("cc_mesh"):
             vtk_file_path = subject_dir.filename_by_attribute("cc_mesh")
@@ -296,7 +292,7 @@ def recon_cc_surf_measures_multi(
     if subject_dir.has_attribute("cc_thickness_overlay"):
         overlay_file_path = subject_dir.filename_by_attribute("cc_thickness_overlay")
         logger.info(f"Saving overlay file to {overlay_file_path}")
-        io_futures.append(thread_executor().submit(cc_mesh.write_overlay, overlay_file_path))
+        io_futures.append(thread_executor().submit(cc_mesh.write_morph_data, overlay_file_path))
 
     if subject_dir.has_attribute("cc_surf"):
         surf_file_path = subject_dir.filename_by_attribute("cc_surf")
diff --git a/CorpusCallosum/utils/visualization.py b/CorpusCallosum/utils/visualization.py
index fb3742d2..3a44c634 100644
--- a/CorpusCallosum/utils/visualization.py
+++ b/CorpusCallosum/utils/visualization.py
@@ -14,6 +14,7 @@
 
 from pathlib import Path
 
+import matplotlib
 import matplotlib.pyplot as plt
 import nibabel as nib
 import numpy as np
@@ -168,6 +169,9 @@ def plot_contours(
     if vox_size is None and None in (split_contours, midline_equidistant, levelpaths):
         raise ValueError("vox_size must be provided if split_contours, midline_equidistant, or levelpaths are given.")
 
+    if output_path is not None:
+        matplotlib.use('Agg')  # use a non-GUI backend for headless figure export
+
     # convert vox_size from LIA to AS
     vox_size_ras = np.asarray([vox_size[0], vox_size[2], vox_size[1]]) if vox_size is not None else None
 
From 8bab13867dbfc87779f80e451f7a46290c93a86d Mon Sep 17 00:00:00 2001
From: ClePol
Date: Tue, 9 Dec 2025 18:50:00 +0100
Subject: [PATCH 46/68] updated cc visualization script with cleaner interface
 to Mesh, Contour, fsaverage. Fixed multithreading bug.
 Documentation
---
 CorpusCallosum/cc_visualization.py           | 222 +++++-----
 CorpusCallosum/data/fsaverage_cc_template.py |  15 +-
 CorpusCallosum/shape/contour.py              | 418 ++++++++++++++++---
 CorpusCallosum/shape/mesh.py                 | 306 +-------------
 CorpusCallosum/shape/postprocessing.py       |  10 +-
 doc/scripts/cc_visualization.rst             |  17 +-
 doc/scripts/contour.rst                      |  36 ++
 7 files changed, 538 insertions(+), 486 deletions(-)
 create mode 100644 doc/scripts/contour.rst

diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py
index 4a18f229..d0b94d2b 100644
--- a/CorpusCallosum/cc_visualization.py
+++ b/CorpusCallosum/cc_visualization.py
@@ -8,43 +8,37 @@
 from CorpusCallosum.data.constants import FSAVERAGE_DATA_PATH
 from CorpusCallosum.data.fsaverage_cc_template import load_fsaverage_cc_template
 from CorpusCallosum.data.read_write import load_fsaverage_data
-from CorpusCallosum.shape.mesh import CCMesh
+from CorpusCallosum.shape.contour import CCContour
+from CorpusCallosum.shape.mesh import create_CC_mesh_from_contours
 
 
 def make_parser() -> argparse.ArgumentParser:
     """Create a command line parser for the visualization pipeline."""
     parser = argparse.ArgumentParser(description="Visualize corpus callosum from template files.")
     parser.add_argument(
-        "--contours",
-        type=str,
-        required=False,
-        help="Path to contours.txt file if not provided, uses fsaverage template.",
-        metavar="CONTOURS_PATH",
-        default=None
-    )
-    parser.add_argument(
-        "--thickness",
-        type=str,
-        required=True,
-        help="Path to thickness_values.txt file.",
-        metavar="THICKNESS_VALUES_PATH"
-    )
-    parser.add_argument(
-        "--measurement_points",
+        "--template_dir",
         type=str,
         required=True,
-        help="Path to measurement points file containing the original vertex indices where thickness was measured.",
+        help=(
+            "Path to a template directory containing per-slice files named "
+            "thickness_values_<idx>.txt, and optionally contour_<idx>.txt "
+            "and thickness_measurement_points_<idx>.txt. If contour_<idx>.txt "
+            "and thickness_measurement_points_<idx>.txt are not provided, "
+            "uses fsaverage template."
+        ),
+        metavar="TEMPLATE_DIR",
+        default=None,
     )
     parser.add_argument("--output_dir",
                         type=str,
                         required=True,
-                        help="Directory for output files. Writes: \\\
-                        cc_mesh.html - Interactive 3D mesh visualization (HTML file) \\\
-                        midslice_2d.png - 2D midslice visualization of the corpus callosum \\\
-                        cc_mesh.vtk - VTK mesh file format \\\
-                        cc_mesh.fssurf - FreeSurfer surface file \\\
-                        cc_mesh_overlay.curv - FreeSurfer curvature overlay file \\\
-                        cc_mesh_snap.png - Screenshot/snapshot of the 3D mesh (requires whippersnappy>=1.3.1)",
+                        help="Directory for output files. Writes: "
+                             "cc_mesh.html - Interactive 3D mesh visualization (HTML file) "
+                             "midslice_2d.png - 2D midslice visualization of the corpus callosum "
+                             "cc_mesh.vtk - VTK mesh file format "
+                             "cc_mesh.fssurf - FreeSurfer surface file "
+                             "cc_mesh_overlay.curv - FreeSurfer curvature overlay file "
+                             "cc_mesh_snap.png - Screenshot/snapshot of the 3D mesh "
+                             "(requires whippersnappy>=1.3.1)",
                         metavar="OUTPUT_DIR"
     )
     parser.add_argument(
@@ -94,7 +88,8 @@ def make_parser() -> argparse.ArgumentParser:
 
 def options_parse() -> argparse.Namespace:
     """Parse command line arguments for the pipeline."""
-    args = make_parser().parse_args()
+    parser = make_parser()
+    args = parser.parse_args()
 
     # Create output directory if it doesn't exist
     Path(args.output_dir).mkdir(parents=True, exist_ok=True)
@@ -102,10 +97,60 @@ def options_parse() -> argparse.Namespace:
     return args
 
 
+def load_contours_from_template_dir(
+    template_dir: Path, resolution: float, smoothing_window: int
+) -> list[CCContour]:
+    """Load all contours and thickness data from a template directory."""
+    thickness_files = sorted(template_dir.glob("thickness_values_*.txt"))
+    if not thickness_files:
+        raise FileNotFoundError(
+            f"No thickness files found in template directory {template_dir}. "
+            "Expected files named thickness_values_<idx>.txt and "
+            "optionally contour_<idx>.txt and thickness_measurement_points_<idx>.txt."
+        )
+
+    fsaverage_contour = None
+
+    contours: list[CCContour] = []
+    for thickness_file in thickness_files:
+        try:
+            idx = int(thickness_file.stem.split("_")[-1])
+        except ValueError:
+            # skip files that do not follow the expected naming
+            continue
+
+        contour_file = template_dir / f"contour_{idx}.txt"
+
+        if not contour_file.exists():
+            # read the thickness file to determine how many values it holds
+            thickness_values = np.loadtxt(thickness_file, dtype=str)
+            # count the non-nan thickness values (excluding the header), so we know how many points to sample
+            num_thickness_values = np.sum(~np.isnan(np.array(thickness_values[1:], dtype=float)))
+            if fsaverage_contour is None:
+                fsaverage_contour = load_fsaverage_cc_template()
+                # create measurement points (points = 2 x levelpaths) according to the number of thickness values
+                fsaverage_contour.create_levelpaths(num_points=num_thickness_values // 2, update_data=True)
+            current_contour = fsaverage_contour.copy()
+            current_contour.load_thickness_values(thickness_file)
+
+        else:
+            # TODO: consider a constructor overload that loads contour and thickness values directly
+            current_contour = CCContour(np.empty((0, 2)), np.empty((0,)), resolution=resolution)
+            current_contour.load_contour(contour_file)
+            current_contour.load_thickness_values(thickness_file)
+
+        current_contour.fill_thickness_values()
+        contours.append(current_contour)
+
+    if not contours:
+        raise ValueError(f"No valid contours could be loaded from {template_dir}")
+    return contours
+
+
 def main(
-    contours_path: str | Path | None,
-    thickness_path: str | Path,
-    measurement_points_path: str | Path,
+    template_dir: str | Path,
     output_dir: str | Path,
     resolution: float = 1.0,
     smoothing_window: int = 5,
@@ -114,102 +159,55 @@ def main(
     legend: str | None = None,
     twoD: bool = False,
 ) -> Literal[0] | str:
-    """Main function to visualize corpus callosum from template files.
-
-    This function loads contours and thickness values from template files,
-    creates a CC_Mesh object, and generates visualizations.
-
-    Parameters
-    ----------
-    contours_path : str or Path or None
-        Path to contours.txt file.
-    thickness_path : str or Path
-        Path to thickness_values.txt file.
- measurement_points_path : str or Path - Path to file containing original vertex indices where thickness was measured. - output_dir : str or Path - Directory for output files. - resolution : float, optional - Resolution in mm for the mesh, by default 1.0. - smoothing_window : int, optional - Window size for smoothing the contour, by default 5. - colormap : str, optional - Colormap to use for visualization, by default "red_to_yellow". - Options: - - "red_to_blue": Red -> Orange -> Grey -> Light Blue -> Blue - - "blue_to_red": Blue -> Light Blue -> Grey -> Orange -> Red - - "red_to_yellow": Red -> Yellow -> Light Blue -> Blue - - "yellow_to_red": Yellow -> Light Blue -> Blue -> Red - color_range : tuple[float, float], optional - Fixed range (min, max) for the colorbar, by default None. - legend : str, optional - Legend for the colorbar, by default None. - twoD : bool, optional - If True, generate 2D visualization instead of 3D mesh, by default False. - """ - # Convert paths to Path objects - contours_path = Path(contours_path) if contours_path is not None else None - thickness_path = Path(thickness_path) - measurement_points_path = Path(measurement_points_path) + """Visualize corpus callosum templates in 2D or 3D.""" output_dir = Path(output_dir) - - # Load data and create mesh - cc_mesh = CCMesh(num_slices=1) # Will be resized when loading data + color_range = tuple(color_range) if color_range is not None else None _, _, vox2ras_tkr = load_fsaverage_data(FSAVERAGE_DATA_PATH) - if contours_path is not None: - cc_mesh.load_contours(str(contours_path)) - else: - cc_contour, anterior_endpoint_idx, posterior_endpoint_idx = load_fsaverage_cc_template() - cc_mesh.contours[0] = np.stack(cc_contour).T - cc_mesh.start_end_idx[0] = [anterior_endpoint_idx, posterior_endpoint_idx] - - cc_mesh.load_thickness_values(str(thickness_path), str(measurement_points_path)) - cc_mesh.set_resolution(resolution) + contours = load_contours_from_template_dir( + Path(template_dir), resolution=resolution, smoothing_window=smoothing_window + ) + # 2D visualization + mid_contour = contours[len(contours) // 2] if twoD: - # cc_mesh.smooth_contour(contour_idx=0, window_size=5) - cc_mesh.plot_cc_contour_with_levelsets( - contour_idx=0, levelpaths=None, title=None, save_path=str(output_dir / "cc_thickness_2d.png"), colorbar=True - ) - else: - cc_mesh.fill_thickness_values() - # Create and process mesh - cc_mesh.create_mesh(smooth=smoothing_window, closed=False) - - # Generate visualizations - cc_mesh.plot_mesh( - colormap=colormap, - color_range=color_range, - thickness_overlay=True, - show_contours=False, - show_mesh_edges=True, - legend=legend, + mid_contour.plot_cc_contour_with_levelsets( + title=None, + save_path=str(output_dir / "cc_thickness_2d.png"), + colorbar=True, ) - cc_mesh.plot_mesh(str(output_dir / "cc_mesh.html"), thickness_overlay=True) + return 0 - cc_mesh.plot_cc_contour_with_levelsets( - contour_idx=len(cc_mesh.contours) // 2, save_path=str(output_dir / "midslice_2d.png") - ) + # 3D visualization + cc_mesh = create_CC_mesh_from_contours(contours, smooth=0) - cc_mesh.to_fs_coordinates(vox_size=[resolution, resolution, resolution], vox2ras_tkr=vox2ras_tkr) - cc_mesh.write_vtk(str(output_dir / "cc_mesh.vtk")) - cc_mesh.write_fssurf(str(output_dir / "cc_mesh.fssurf")) - cc_mesh.write_overlay(str(output_dir / "cc_mesh_overlay.curv")) - try: - cc_mesh.snap_cc_picture(str(output_dir / "cc_mesh_snap.png")) - except RuntimeError: - return ("The cc_visualization script requires whippersnappy>=1.3.1 to makes 
screenshots, install with "
-                "`pip install whippersnappy>=1.3.1` !")
-    return 0
+    plot_kwargs = dict(
+        colormap=colormap,
+        color_range=color_range,
+        thickness_overlay=True,
+        legend=legend or "",
+    )
+    cc_mesh.plot_mesh(**plot_kwargs)
+    cc_mesh.plot_mesh(output_path=str(output_dir / "cc_mesh.html"), **plot_kwargs)
+
+    mid_contour.plot_cc_contour_with_levelsets(save_path=str(output_dir / "midslice_2d.png"))
+
+    cc_mesh.to_fs_coordinates(vox2ras_tkr=vox2ras_tkr)
+    cc_mesh.write_vtk(str(output_dir / "cc_mesh.vtk"))
+    cc_mesh.write_fssurf(str(output_dir / "cc_mesh.fssurf"))
+    cc_mesh.write_morph_data(str(output_dir / "cc_mesh_overlay.curv"))
+    try:
+        cc_mesh.snap_cc_picture(str(output_dir / "cc_mesh_snap.png"))
+    except RuntimeError:
+        return ("The cc_visualization script requires whippersnappy>=1.3.1 to make screenshots, install with "
+                "`pip install whippersnappy>=1.3.1` !")
     return 0
 
 
 if __name__ == "__main__":
-    options = make_parser().parse_args()
+    options = options_parse()
     sys.exit(main(
-        contours_path=options.contours,
-        thickness_path=options.thickness,
-        measurement_points_path=options.measurement_points,
+        template_dir=options.template_dir,
         output_dir=options.output_dir,
         resolution=options.resolution,
         smoothing_window=options.smoothing_window,
diff --git a/CorpusCallosum/data/fsaverage_cc_template.py b/CorpusCallosum/data/fsaverage_cc_template.py
index e307dedd..55d11bc2 100644
--- a/CorpusCallosum/data/fsaverage_cc_template.py
+++ b/CorpusCallosum/data/fsaverage_cc_template.py
@@ -20,6 +20,7 @@
 from scipy import ndimage
 
 from CorpusCallosum.data import constants
+from CorpusCallosum.shape.contour import CCContour
 from CorpusCallosum.shape.postprocessing import recon_cc_surf_measure
 from FastSurferCNN.utils.brainvolstats import mask_in_array
 
@@ -132,12 +133,11 @@ def load_fsaverage_cc_template() -> tuple[
         contour_smoothing=5,
         vox_size=(1., 1., 1.),  # fsaverage is in 1mm isotropic
     )
-    outside_contour = contour_with_thickness[:2].T
-
+    outside_contour = contour_with_thickness[:, :2].T
 
     # make sure the CC stays in shape despite smoothing by moving endpoints outwards
-    outside_contour[0][anterior_endpoint_idx] -= 55
-    outside_contour[0][posterior_endpoint_idx] += 30
+    outside_contour[0, anterior_endpoint_idx] -= 55
+    outside_contour[0, posterior_endpoint_idx] += 30
 
     # Apply smoothing to the outside contour
     outside_contour_smoothed = smooth_contour(outside_contour, window_size=11)
@@ -145,5 +145,10 @@
     outside_contour_smoothed = smooth_contour(outside_contour_smoothed, window_size=30)
     outside_contour = outside_contour_smoothed
 
+    fsaverage_contour = CCContour(np.array(outside_contour).T,
+                                  np.zeros(len(outside_contour[0])),
+                                  endpoint_idxs=(anterior_endpoint_idx, posterior_endpoint_idx),
+                                  resolution=1.0)
 
-    return outside_contour, anterior_endpoint_idx, posterior_endpoint_idx
+    return fsaverage_contour
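For orientation, the per-slice template convention introduced above can also be exercised directly from Python. A minimal sketch (the paths are hypothetical; the directory layout follows the `--template_dir` help text, and the output directory is assumed to exist):

```python
from pathlib import Path

from CorpusCallosum.cc_visualization import load_contours_from_template_dir, main

# Hypothetical template directory, as written by the postprocessing step:
#   /data/subject1/template/thickness_values_0.txt
#   /data/subject1/template/contour_0.txt                       (optional)
#   /data/subject1/template/thickness_measurement_points_0.txt  (optional)
template_dir = Path("/data/subject1/template")

# One CCContour per thickness_values_<idx>.txt; slices without a matching
# contour_<idx>.txt fall back to the fsaverage template contour.
contours = load_contours_from_template_dir(template_dir, resolution=1.0, smoothing_window=5)

# Or run the full visualization entry point (writes cc_mesh.html, cc_mesh.vtk, ...):
main(template_dir=template_dir, output_dir="/data/subject1/cc_vis")
```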
diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py
index 36aa3f80..aad028b5 100644
--- a/CorpusCallosum/shape/contour.py
+++ b/CorpusCallosum/shape/contour.py
@@ -12,16 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import re
 from pathlib import Path
 from typing import Literal
 
 import lapy
+import matplotlib
+import matplotlib.pyplot as plt
 import numpy as np
 import scipy.interpolate
 from scipy.ndimage import gaussian_filter1d
 
 import FastSurferCNN.utils.logging as logging
 from CorpusCallosum.shape.endpoint_heuristic import smooth_contour
+from CorpusCallosum.shape.thickness import cc_thickness, make_mesh_from_contour
 from FastSurferCNN.utils.common import suppress_stdout
 
 logger = logging.get_logger(__name__)
@@ -60,7 +64,11 @@ def __init__(
             Tuple containing start and end indices for the contour.
         """
         self.contour = contour
+        if self.contour.shape[1] != 2:
+            raise ValueError(f"Contour must be an (N, 2) array, but has shape {self.contour.shape}")
         self.thickness_values = thickness_values
+        if self.contour.shape[0] != len(thickness_values):
+            raise ValueError(f"Number of contour points ({self.contour.shape[0]}) does not match "
+                             f"number of thickness values ({len(thickness_values)})")
         # write vertex indices where thickness values are not nan
         self.original_thickness_vertices = np.where(~np.isnan(thickness_values))[0]
         self.resolution = resolution
@@ -91,6 +99,11 @@ def smooth_contour(self, window_size: int = 5) -> None:
         x, y = smooth_contour(x, y, window_size)
         self.contour = np.array([x, y]).T
 
+    def copy(self) -> "CCContour":
+        """Return a copy of this contour."""
+        return CCContour(self.contour.copy(), self.thickness_values.copy(), self.endpoint_idxs, self.resolution)
+
     def get_contour_edge_lengths(self) -> np.ndarray:
         """Get the lengths of the edges of a contour.
 
@@ -113,7 +126,45 @@ def get_contour_edge_lengths(self) -> np.ndarray:
         edges = np.diff(self.contour, axis=0)
         return np.sqrt(np.sum(edges**2, axis=1))
 
+    def create_levelpaths(self,
+                          num_points: int,
+                          update_data: bool = True
+                          ) -> tuple[list[np.ndarray], list[float]]:
+        """Compute thickness level paths across the contour via cc_thickness.
+
+        Parameters
+        ----------
+        num_points : int
+            Number of midline points at which level paths are sampled.
+        update_data : bool, optional
+            Whether to update contour, thickness values, measurement vertices,
+            and endpoint indices in place, by default True.
+
+        Returns
+        -------
+        tuple[list[np.ndarray], list[float]]
+            The level paths and the thickness value at each level.
+        """
+        (midline_len, thickness, curvature, midline_equi,
+         levelpaths, contour_with_thickness, endpoint_idxs) = cc_thickness(
+            self.contour,
+            self.endpoint_idxs,
+            n_points=num_points,
+        )
+
+        if update_data:
+            self.contour = contour_with_thickness[:, :2]
+            self.thickness_values = contour_with_thickness[:, 2]
+            self.original_thickness_vertices = np.where(~np.isnan(thickness))[0]
+            self.endpoint_idxs = endpoint_idxs
+        return levelpaths, thickness
+
+    def set_thickness_values(self, thickness_values: np.ndarray, use_measurement_points: bool = False) -> None:
+        """Set the thickness values for the contour.
+
+        This is useful to update the thickness values for specific plots.
+
+        Parameters
+        ----------
+        thickness_values : np.ndarray
+            Array of thickness values for the contour.
+        use_measurement_points : bool, optional
+            Whether to set the values only at the recorded measurement points, by default False.
+        """
+        if use_measurement_points:
+            assert len(thickness_values) == len(self.original_thickness_vertices), \
+                "Number of thickness values does not match number of measurement points"
+            self.thickness_values = np.full(len(self.contour), np.nan)
+            self.thickness_values[self.original_thickness_vertices] = thickness_values
+        else:
+            assert len(thickness_values) == len(self.contour), \
+                "Number of thickness values does not match number of points in the contour"
+            self.thickness_values = thickness_values
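As a usage note for the new `set_thickness_values` hook: it is the intended way to project externally computed per-point statistics onto a contour before plotting. A small sketch with made-up numbers (the contour geometry and the statistics are hypothetical):

```python
import numpy as np

from CorpusCallosum.shape.contour import CCContour

# Hypothetical 6-point contour; thickness is known at two vertices only.
points = np.array([[0, 0], [1, 0], [2, 0], [2, 1], [1, 1], [0, 1]], dtype=float)
thickness = np.array([np.nan, 2.0, np.nan, np.nan, 3.0, np.nan])
contour = CCContour(points, thickness, endpoint_idxs=(0, 3), resolution=1.0)

# Overwrite values only at the recorded measurement vertices (here indices 1 and 4),
# e.g. with group-level statistics such as p-values:
contour.set_thickness_values(np.array([0.01, 0.04]), use_measurement_points=True)

# Or replace the full per-point array in one go:
contour.set_thickness_values(np.full(len(points), 2.5), use_measurement_points=False)
```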
     def _create_levelpaths(
         self,
@@ -232,6 +283,15 @@ def fill_thickness_values(self) -> None:
         # Find indices of points with known thickness
         known_idx = np.where(~np.isnan(thickness))[0]
 
+        if len(known_idx) == 0:
+            logger.warning("No known thickness values; skipping interpolation")
+            return
+        if len(known_idx) == 1:
+            logger.warning("Only one known thickness value; skipping interpolation")
+            thickness[np.isnan(thickness)] = thickness[known_idx[0]]
+            self.thickness_values = thickness
+            return
+
         # For each point with unknown thickness
         for j in range(len(thickness)):
             if not np.isnan(thickness[j]):
@@ -278,47 +338,274 @@ def smooth_thickness_values(self, iterations: int = 1) -> None:
             if self.thickness_values[i] is not None:
                 self.thickness_values[i] = gaussian_filter1d(self.thickness_values[i], sigma=5)
 
-    @staticmethod
-    def __make_parent_folder(filename: Path | str) -> None:
-        """Create the parent folder for a file if it doesn't exist.
+    def plot_contour(self, output_path: str) -> None:
+        """Plot a single contour with thickness values.
 
         Parameters
         ----------
-        filename : Path, str
-            Path to the file whose parent folder should be created.
+        output_path : str
+            Path where to save the plot.
 
         Notes
         -----
-        Creates parent directory with parents=False to avoid creating
-        multiple levels of directories unintentionally.
+        Creates a 2D visualization with:
+        - Points colored by thickness values.
+        - Gray points for missing thickness values.
+        - Connected contour line.
+        - Grid, labels, and legend.
         """
-        Path(filename).parent.mkdir(parents=False, exist_ok=True)
-
+        self.__make_parent_folder(output_path)
+
+        contour = self.contour
 
-    def save_thickness_measurement_points(self, filename: Path | str) -> None:
-        """Write the thickness measurement points to a CSV file.
+        plt.figure(figsize=(10, 10))
+        # Get thickness values for this slice
+        thickness = self.thickness_values
+
+        # Plot points with colors based on thickness
+        for i in range(len(contour)):
+            if np.isnan(thickness[i]):
+                plt.plot(contour[i, 0], contour[i, 1], "o", color="gray", markersize=1)
+            else:
+                # Map thickness to color from red to yellow
+                plt.plot(
+                    contour[i, 0],
+                    contour[i, 1],
+                    "o",
+                    color=plt.cm.YlOrRd(thickness[i] / np.nanmax(thickness)),
+                    markersize=1,
+                )
+
+        # Connect points with lines
+        plt.plot(contour[:, 0], contour[:, 1], "-", color="black", alpha=0.3, label="Contour")
+        plt.axis("equal")
+        plt.xlabel("X")
+        plt.ylabel("Y")
+        plt.title("CC contour")
+        plt.legend()
+        plt.grid(True)
+        plt.tight_layout()
+        plt.savefig(output_path, dpi=300)
+
+    def plot_cc_contour_with_levelsets(
+        self,
+        title: str | None = None,
+        save_path: str | None = None,
+        colorbar: bool = True,
+        mode: str = "p-value",
+    ) -> matplotlib.figure.Figure:
+        """Plot a contour with levelset visualization.
+
+        Creates a visualization of a contour with interpolated levelsets, useful for
+        analyzing the thickness distribution across the corpus callosum.
+
+        Parameters
+        ----------
+        title : str, optional
+            Title for the plot.
+        save_path : str, optional
+            Path to save the plot. If None, displays interactively.
+        colorbar : bool, default=True
+            Whether to show the colorbar.
+        mode : {"p-value", "icc"}, default="p-value"
+            Mode of the plot.
+
+        Returns
+        -------
+        matplotlib.figure.Figure
+            The created figure object.
+        """
+        plot_values = np.array(self.thickness_values[~np.isnan(self.thickness_values)])[::-1]
+        points, trias = make_mesh_from_contour(self.contour, max_volume=0.5, min_angle=25, verbose=False)
+
+        # make points 3D by adding zero
+        points = np.column_stack([points, np.zeros(len(points))])
+
+        levelpaths, _ = self.create_levelpaths(num_points=len(plot_values) - 1, update_data=False)
+
+        outside_contour = self.contour.T
+
+        # Create a grid of points covering the contour area with higher resolution
+        x_min, x_max = np.min(outside_contour[0]), np.max(outside_contour[0])
+        y_min, y_max = np.min(outside_contour[1]), np.max(outside_contour[1])
+        margin = 1
+        resolution = 0.05  # Higher resolution for smoother interpolation
+        x_grid, y_grid = np.meshgrid(
+            np.arange(x_min - margin, x_max + margin, resolution), np.arange(y_min - margin, y_max + margin, resolution)
+        )
+
+        # Create a path from the outside contour
+        contour_path = matplotlib.path.Path(np.column_stack([outside_contour[0], outside_contour[1]]))
+
+        # Check which points are inside the contour
+        points = np.column_stack([x_grid.flatten(), y_grid.flatten()])
+        mask = contour_path.contains_points(points).reshape(x_grid.shape)
+
+        # Collect all levelpath points and their corresponding values
+        # Extend each levelpath at both ends to improve extrapolation
+        all_level_points_x = []
+        all_level_points_y = []
+        all_level_values = []
+
+        for i, path in enumerate(levelpaths):
+
+            # add third dimension to path
+            path = np.column_stack([path, np.zeros(len(path))])
+
+            if len(path) == 1:
+                all_level_points_x.append(path[0][0])
+                all_level_points_y.append(path[0][1])
+                all_level_values.append(plot_values[i])
+                continue
+
+            # resample the levelpath
+            path = lapy.TriaMesh._TriaMesh__resample_polygon(path, 1000)
+
+            # Extend at the beginning: add point in direction opposite to first segment
+            first_segment = path[1] - path[0]
+            # standardize length of first segment
+            first_segment = first_segment / np.linalg.norm(first_segment) * 10
+            extension_start = path[0] - first_segment
+            all_level_points_x.append(extension_start[0])
+            all_level_points_y.append(extension_start[1])
+            all_level_values.append(plot_values[i])
+
+            # Add original path points
+            for point in path:
+                all_level_points_x.append(point[0])
+                all_level_points_y.append(point[1])
+                all_level_values.append(plot_values[i])
+
+            # Extend at the end: add point in direction of last segment
+            last_segment = path[-1] - path[-2]
+            # standardize length of last segment
+            last_segment = last_segment / np.linalg.norm(last_segment) * 10
+            extension_end = path[-1] + last_segment
+            all_level_points_x.append(extension_end[0])
+            all_level_points_y.append(extension_end[1])
+            all_level_values.append(plot_values[i])
+
+        # Convert to numpy arrays
+        all_level_points_x = np.array(all_level_points_x)
+        all_level_points_y = np.array(all_level_points_y)
+        all_level_values = np.array(all_level_values)
+
+        # Use griddata to perform smooth interpolation - using 'linear' instead of 'cubic'
+        # and properly formatting the input points
+        grid_values = scipy.interpolate.griddata(
+            (all_level_points_x, all_level_points_y), all_level_values, (x_grid, y_grid), method="linear", fill_value=0,
+        )
+
+        # smooth the grid_values
+        grid_values = scipy.ndimage.gaussian_filter(grid_values, sigma=5, radius=5)
+
+        # Apply the mask to only show values inside the contour
+        masked_values = np.where(mask, grid_values, np.nan)
+
+        if mode == "p-value":
+            # Sample colormaps
+            colors1 = plt.cm.binary([0.4] * 128)
+            colors2 = plt.cm.hot(np.linspace(0.8, 0.1, 128))
+        elif mode == "icc":
+            colors1 = plt.cm.Blues(np.linspace(0, 1, 128))
+            colors2 = plt.cm.binary([0.4] * 128)
+        else:
+            raise ValueError(f"Invalid mode '{mode}'")
+
+        # Combine the color samples
+        colors = np.vstack((colors2, colors1))
+
+        # Create a new colormap
+        cmap = matplotlib.colors.LinearSegmentedColormap.from_list("my_colormap", colors)
+
+        # Plot CC contour with levelsets
+        fig = plt.figure(figsize=(10, 3))
+        # Apply a 10-degree rotation to the entire plot
+        base = plt.gca().transData
+        transform = matplotlib.transforms.Affine2D().rotate_deg(10)
+        transform = transform + base
+
+        # Plot the filled contour with interpolated colors
+        plt.imshow(
+            masked_values,
+            extent=(x_min - margin, x_max + margin, y_min - margin, y_max + margin),
+            origin="lower",
+            cmap=cmap,
+            alpha=1,
+            interpolation="bilinear",
+            vmin=0,
+            vmax=0.10 if mode == "p-value" else 1,
+            transform=transform,
+        )
+
+        if colorbar:
+            # Add a colorbar
+            cbar = plt.colorbar(aspect=10)
+            if mode == "p-value":
+                cbar.ax.set_ylim(0.001, 0.054)
+                cbar.ax.set_yticks([0.0, 0.01, 0.02, 0.03, 0.04, 0.05])
+                cbar.set_label("p-value (log scale)")
+            elif mode == "icc":
+                cbar.ax.set_ylim(0, 1)
+                cbar.ax.set_yticks([0, 0.25, 0.5, 0.75, 1])
+                cbar.set_label("Intraclass correlation coefficient")
+
+        # Plot the outside contour on top for clear boundary
+        plt.plot(outside_contour[0], outside_contour[1], "k-", linewidth=2, label="CC Contour", transform=transform)
+
+        plt.axis("equal")
+        plt.title(title, fontsize=14, fontweight="bold")
+        plt.gca().invert_xaxis()
+        plt.axis("off")
+        if save_path is not None:
+            self.__make_parent_folder(save_path)
+            plt.savefig(save_path, dpi=300)
+        else:
+            plt.show()
+        return fig
+
+    @staticmethod
+    def __make_parent_folder(filename: Path | str) -> None:
+        """Create the parent folder for a file if it doesn't exist.
 
         Parameters
         ----------
         filename : Path, str
-            Path where to save the CSV file.
+            Path to the file whose parent folder should be created.
 
         Notes
         -----
-        The function saves measurement points in CSV format with:
-        - Header: slice_idx,vertex_idx.
-        - Each measurement point gets its own row.
-        - Skips slices with no measurement points.
+        Creates parent directory with parents=False to avoid creating
+        multiple levels of directories unintentionally.
         """
-        self.__make_parent_folder(filename)
-        logger.info(f"Saving thickness measurement points to CSV file: {filename}")
-        with open(filename, "w") as f:
-            f.write("vertex_idx\n")
-            for vertex_idx in self.original_thickness_vertices:
-                f.write(f"{vertex_idx}\n")
+        Path(filename).parent.mkdir(parents=False, exist_ok=True)
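The guards added to `fill_thickness_values` above protect its inverse-distance interpolation against contours with zero or one known value. For reference, the weighting rule itself reduces to a few lines; a standalone restatement with made-up numbers:

```python
import numpy as np

# Distances along the contour to the two nearest measured points, and their values
# (hypothetical numbers, mirroring the rule in CCContour.fill_thickness_values).
d = np.array([2.0, 6.0])         # contour distances to the two closest known points
t = np.array([3.0, 5.0])         # thickness at those points
w = (1.0 / d) / np.sum(1.0 / d)  # inverse-distance weights -> [0.75, 0.25]
filled = np.sum(w * t)           # 0.75 * 3.0 + 0.25 * 5.0 = 3.5
print(filled)
```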
""" - self.__make_parent_folder(filename) - logger.info(f"Saving thickness measurement points to CSV file: {filename}") - with open(filename, "w") as f: - f.write("vertex_idx\n") - for vertex_idx in self.original_thickness_vertices: - f.write(f"{vertex_idx}\n") + Path(filename).parent.mkdir(parents=False, exist_ok=True) @staticmethod - def _load_thickness_measurement_points(filename: str) -> list[np.ndarray | None]: + def _load_thickness_measurement_points(filename: str) -> np.ndarray: """Load thickness measurement points from a CSV file. Parameters @@ -328,31 +615,25 @@ def _load_thickness_measurement_points(filename: str) -> list[np.ndarray | None] Returns ------- - list[np.ndarray | None] - List of arrays containing vertex indices for each slice where - thickness was measured. None for slices without measurements. + np.ndarray + Array containing vertex indices where thickness was measured. Notes ----- The function: - 1. Reads CSV file with format: slice_idx,vertex_idx - 2. Groups vertex indices by slice index - 3. Creates a list with length matching max slice index - 4. Fills list with vertex indices arrays or None for missing slices + 1. Reads CSV file with format: slice_idx,vertex_idx (legacy) or a single + vertex_idx column (current output format). + 2. Returns a flat array of vertex indices. """ data = np.loadtxt(filename, delimiter=",", skiprows=1) - slice_indices = data[:, 0].astype(int) - vertex_indices = data[:, 1].astype(int) - - # Group values by slice_idx - unique_slices = np.unique(slice_indices) - - # split data into slices - original_thickness_vertices = [None] * (max(unique_slices) + 1) - for slice_idx in unique_slices: - mask = slice_indices == slice_idx - original_thickness_vertices[slice_idx] = vertex_indices[mask] - return original_thickness_vertices + # handle scalar files (single measurement point) + if data.ndim == 0: + return np.array([int(data)]) + if data.ndim == 1: + return data.astype(int) + if data.shape[1] != 1: + raise ValueError("Thickness measurement points file must contain a single vertex_idx column.") + return data[:, 0].astype(int) @@ -374,11 +655,12 @@ def save_contour(self, output_path: Path | str) -> None: self.__make_parent_folder(output_path) logger.info(f"Saving contours to CSV file: {output_path}") with open(output_path, "w") as f: - f.write("x,y\n") + f.write( f"New contour, anterior_endpoint_idx={self.endpoint_idxs[0]}, " f"posterior_endpoint_idx={self.endpoint_idxs[1]}\n" ) + f.write("x,y\n") for point in self.contour: f.write(f"{point[0]},{point[1]}\n") @@ -405,10 +687,20 @@ def load_contour(self, input_path: str) -> None: """ current_points = [] self.contours = [] - self.start_end_idx = [] + self.endpoint_idxs = [] with open(input_path) as f: - # Skip header + header = next(f).strip() + # Parse endpoint indices from header + anterior_match = re.search(r'anterior_endpoint_idx=(\d+)', header) + posterior_match = re.search(r'posterior_endpoint_idx=(\d+)', header) + assert anterior_match and posterior_match, "Header does not contain endpoint indices" + + anterior_idx = int(anterior_match.group(1)) + posterior_idx = int(posterior_match.group(1)) + self.endpoint_idxs = (anterior_idx, posterior_idx) + + # Skip column names next(f) for line in f: @@ -441,7 +733,6 @@ def save_thickness_values(self, output_path: Path | str) -> None: def load_thickness_values( self, input_path: str, - original_thickness_vertices_path: str | None = None ) -> None: """Load thickness values from a CSV file. 
@@ -470,28 +761,23 @@ def load_thickness_values(
 
         """
         data = np.loadtxt(input_path, delimiter=",", skiprows=1)
-        values = data[:, 0]
-
-        if original_thickness_vertices_path is None:
-            # check that the number of thickness values for each slice is equal to the number of points in the contour
-            assert len(values) == len(self.contour), (
-                "Number of thickness values does not match number of points in the contour, maybe you need to "
-                "provide the measurement points file"
-            )
-            # fill original_thickness_vertices with all indices
-            self.original_thickness_vertices = np.arange(len(self.contour))
+        if data.ndim == 0:
+            values = np.array([float(data)])
+        elif data.ndim == 1:
+            values = data.astype(float)
         else:
-            loaded_original_thickness_vertices = self._load_thickness_measurement_points(
-                original_thickness_vertices_path
-            )
+            raise ValueError("Thickness values file must contain a single column")
 
-            if len(loaded_original_thickness_vertices) != len(values):
+        if len(values) != len(self.contour):
+            if np.sum(~np.isnan(values)) == len(self.original_thickness_vertices):
+                new_values = np.full(len(self.contour), np.nan)
+                new_values[self.original_thickness_vertices] = values[~np.isnan(values)]
+            else:
                 raise ValueError(
-                    "Number of measurement points does not match number of thickness values"
+                    f"Number of thickness values ({len(values)}) does not match the number of points in the "
+                    f"contour ({len(self.contour)}), and the number of non-nan thickness values "
+                    f"({np.sum(~np.isnan(values))}) does not match the current number of measurement points "
+                    f"({len(self.original_thickness_vertices)})."
                 )
-            self.thickness_values = values
-            logger.error(
-                f"Tried to load {len(values[~np.isnan(values)])} values, but template has {len(values)} values, "
-                "supply a correct template to visualize the thickness values"
-            )
\ No newline at end of file
+        else:
+            # counts match: use the loaded values directly
+            new_values = values
+        self.thickness_values = new_values
\ No newline at end of file
diff --git a/CorpusCallosum/shape/mesh.py b/CorpusCallosum/shape/mesh.py
index 7fc13c6c..0a4ec187 100644
--- a/CorpusCallosum/shape/mesh.py
+++ b/CorpusCallosum/shape/mesh.py
@@ -16,12 +16,9 @@
 from pathlib import Path
 
 import lapy
-import matplotlib
-import matplotlib.pyplot as plt
 import nibabel as nib
 import numpy as np
 import plotly.graph_objects as go
-import scipy.interpolate
 from plotly.io import write_html as plotly_write_html
 from scipy.ndimage import gaussian_filter1d
 
@@ -269,10 +266,10 @@ def create_CC_mesh_from_contours(contours: list[CCContour],
     color_sides = True
     if color_sides:
         left_side_points, left_side_trias, left_side_colors = _create_cap(
-            left_side_points, left_side_trias, 0
+            left_side_points, left_side_trias, contours[0]
         )
         right_side_points, right_side_trias, right_side_colors = _create_cap(
-            right_side_points, right_side_trias, len(contours) - 1
+            right_side_points, right_side_trias, contours[-1]
         )
 
         # reverse right side trias
@@ -304,24 +301,14 @@ class CCMesh(lapy.TriaMesh):
 
     Attributes
     ----------
-    contours : list[np.ndarray]
-        List of numpy arrays containing 2D contour points for each slice.
-    thickness_values : list[np.ndarray]
-        List of thickness measurements for each contour point.
-    start_end_idx : list[tuple[int, int]]
-        List of tuples containing start and end indices for each contour.
-    ac_coords : np.ndarray
-        Coordinates of the anterior commissure.
-    pc_coords : np.ndarray
-        Coordinates of the posterior commissure.
-    resolution : float
-        Spatial resolution of the mesh.
     v : np.ndarray
         Vertex coordinates of the mesh.
     t : np.ndarray
         Triangle indices of the mesh.
- original_thickness_vertices : list[np.ndarray] - List of vertex indices where thickness was originally measured. + mesh_vertex_colors : np.ndarray + Vertex values for each vertex (CC thickness values) + resolution : float + Spatial resolution of the mesh in millimeters. """ def __init__(self, @@ -339,6 +326,8 @@ def __init__(self, List of face indices or array of shape (M, 3). vertex_values : list or numpy.ndarray, optional Vertex values for each vertex (CC thickness values) + resolution : float, optional + Spatial resolution of the mesh in millimeters, by default 1.0. """ super().__init__(np.vstack(vertices), np.vstack(faces)) self.mesh_vertex_colors = vertex_values @@ -579,266 +568,6 @@ def plot_mesh( webbrowser.open(f"file://{temp_path}") - - - def plot_contour(self, slice_idx: int, output_path: str) -> None: - """Plot a single contour with thickness values. - - Parameters - ---------- - slice_idx : int - Index of the slice to plot. - output_path : str - Path where to save the plot. - - Raises - ------ - ValueError - If the contour for the specified slice is not set. - - Notes - ----- - Creates a 2D visualization with: - - Points colored by thickness values. - - Gray points for missing thickness values. - - Connected contour line. - - Grid, labels, and legend. - """ - self.__make_parent_folder(output_path) - - if self.contours[slice_idx] is None: - raise ValueError(f"Contour for slice {slice_idx} is not set") - - contour = self.contours[slice_idx] - - plt.figure(figsize=(15, 10)) - # Get thickness values for this slice - thickness = self.thickness_values[slice_idx] - - # Plot points with colors based on thickness - for i in range(len(contour)): - if np.isnan(thickness[i]): - plt.plot(contour[i, 0], contour[i, 1], "o", color="gray", markersize=1) - else: - # Map thickness to color from red to yellow - plt.plot( - contour[i, 0], - contour[i, 1], - "o", - color=plt.cm.YlOrRd(thickness[i] / np.nanmax(thickness)), - markersize=1, - ) - - # Connect points with lines - plt.plot(contour[:, 0], contour[:, 1], "-", color="black", alpha=0.3, label="Contour") - plt.axis("equal") - plt.xlabel("X") - plt.ylabel("Y") - plt.title(f"CC contour for slice {slice_idx}") - plt.legend() - plt.grid(True) - plt.tight_layout() - plt.savefig(output_path, dpi=300) - - - - def plot_cc_contour_with_levelsets( - self, - contour_idx: int = 0, - #FIXME: levelpaths is not used - levelpaths: list | None = None, - title: str | None = None, - save_path: str | None = None, - colorbar: bool = True, - mode: str = "p-value", - ) -> matplotlib.figure.Figure: - """Plot a contour with levelset visualization. - - Creates a visualization of a contour with interpolated levelsets, useful for - analyzing the thickness distribution across the corpus callosum. - - Parameters - ---------- - contour_idx : int, default=0 - Index of the contour to plot, by default 0. - levelpaths : list, optional - List of levelset paths. If None, uses stored levelpaths. - title : str, optional - Title for the plot. - save_path : str, optional - Path to save the plot. If None, displays interactively. - colorbar : bool, default=True - Whether to show the colorbar. - mode : {"p-value", "icc"}, default="p-value" - Mode of the plot. - - Returns - ------- - matplotlib.figure.Figure - The created figure object. 
- """ - - plot_values = np.array(self.thickness_values[contour_idx][~np.isnan(self.thickness_values[contour_idx])])[::-1] - points, trias = make_mesh_from_contour(self.contours[contour_idx], max_volume=0.5, min_angle=25, verbose=False) - - # make points 3D by adding zero - points = np.column_stack([points, np.zeros(len(points))]) - - levelpaths, _ = self._create_levelpaths(contour_idx, points, trias, num_points=len(plot_values)-2) - - outside_contour = self.contours[contour_idx].T - - # Create a grid of points covering the contour area with higher resolution - x_min, x_max = np.min(outside_contour[0]), np.max(outside_contour[0]) - y_min, y_max = np.min(outside_contour[1]), np.max(outside_contour[1]) - margin = 1 - resolution = 0.05 # Higher resolution for smoother interpolation - x_grid, y_grid = np.meshgrid( - np.arange(x_min - margin, x_max + margin, resolution), np.arange(y_min - margin, y_max + margin, resolution) - ) - - # Create a path from the outside contour - contour_path = matplotlib.path.Path(np.column_stack([outside_contour[0], outside_contour[1]])) - - # Check which points are inside the contour - points = np.column_stack([x_grid.flatten(), y_grid.flatten()]) - mask = contour_path.contains_points(points).reshape(x_grid.shape) - - # Collect all levelpath points and their corresponding values - # Extend each levelpath at both ends to improve extrapolation - all_level_points_x = [] - all_level_points_y = [] - all_level_values = [] - - for i, path in enumerate(levelpaths): - if len(path) == 1: - all_level_points_x.append(path[0][0]) - all_level_points_y.append(path[0][1]) - all_level_values.append(plot_values[i]) - continue - - # make levelpath - path = lapy.TriaMesh._TriaMesh__resample_polygon(path, 1000) - - # Extend at the beginning: add point in direction opposite to first segment - first_segment = path[1] - path[0] - # standardize length of first segment - first_segment = first_segment / np.linalg.norm(first_segment) * 10 - extension_start = path[0] - first_segment - all_level_points_x.append(extension_start[0]) - all_level_points_y.append(extension_start[1]) - all_level_values.append(plot_values[i]) - - # Add original path points - for point in path: - all_level_points_x.append(point[0]) - all_level_points_y.append(point[1]) - all_level_values.append(plot_values[i]) - - # Extend at the end: add point in direction of last segment - last_segment = path[-1] - path[-2] - # standardize length of last segment - last_segment = last_segment / np.linalg.norm(last_segment) * 10 - extension_end = path[-1] + last_segment - all_level_points_x.append(extension_end[0]) - all_level_points_y.append(extension_end[1]) - all_level_values.append(plot_values[i]) - - # Convert to numpy arrays - all_level_points_x = np.array(all_level_points_x) - all_level_points_y = np.array(all_level_points_y) - all_level_values = np.array(all_level_values) - - # Use griddata to perform smooth interpolation - using 'linear' instead of 'cubic' - # and properly formatting the input points - grid_values = scipy.interpolate.griddata( - (all_level_points_x, all_level_points_y), all_level_values, (x_grid, y_grid), method="linear", fill_value=0, - ) - - # smooth the grid_values - grid_values = scipy.ndimage.gaussian_filter(grid_values, sigma=5, radius=5) - - # Apply the mask to only show values inside the contour - masked_values = np.where(mask, grid_values, np.nan) - - if mode == "p-value": - # Sample colormaps - colors1 = plt.cm.binary([0.4] * 128) - colors2 = plt.cm.hot(np.linspace(0.8, 0.1, 128)) - elif mode == 
"icc": - colors1 = plt.cm.Blues(np.linspace(0, 1, 128)) - colors2 = plt.cm.binary([0.4] * 128) - else: - raise ValueError(f"Invalid mode '{mode}'") - - # Combine the color samples - colors = np.vstack((colors2, colors1)) - - # Create a new colormap - cmap = matplotlib.colors.LinearSegmentedColormap.from_list("my_colormap", colors) - - # Plot CC contour with levelsets - fig = plt.figure(figsize=(10, 3)) - # Apply a 10-degree rotation to the entire plot - base = plt.gca().transData - transform = matplotlib.transforms.Affine2D().rotate_deg(10) - transform = transform + base - - # Plot the filled contour with interpolated colors - plt.imshow( - masked_values, - extent=(x_min - margin, x_max + margin, y_min - margin, y_max + margin), - origin="lower", - cmap=cmap, - alpha=1, - interpolation="bilinear", - vmin=0, - vmax=0.10 if mode == "p-value" else 1, - transform=transform, - ) - - plt.imshow( - masked_values, - extent=(x_min - margin, x_max + margin, y_min - margin, y_max + margin), - origin="lower", - cmap=cmap, - alpha=1, - interpolation="bilinear", - vmin=0, - vmax=0.10 if mode == "p-value" else 1, - # norm=LogNorm(vmin=1e-3, vmax=0.1), # Set minimum to avoid log(0) - transform=transform, - ) - - if colorbar: - # Add a colorbar - cbar = plt.colorbar(aspect=10) - if mode == "p-value": - cbar.ax.set_ylim(0.001, 0.054) - cbar.ax.set_yticks([0.0, 0.01, 0.02, 0.03, 0.04, 0.05]) - cbar.set_label("p-value (log scale)") - elif mode == "icc": - cbar.ax.set_ylim(0, 1) - cbar.ax.set_yticks([0, 0.25, 0.5, 0.75, 1]) - cbar.ax.set_label("Intraclass correlation coefficient") - - # Plot the outside contour on top for clear boundary - plt.plot(outside_contour[0], outside_contour[1], "k-", linewidth=2, label="CC Contour", transform=transform) - - plt.axis("equal") - plt.title(title, fontsize=14, fontweight="bold") - # plt.legend(loc='best') - plt.gca().invert_xaxis() - plt.axis("off") - if save_path is not None: - self.__make_parent_folder(save_path) - plt.savefig(save_path, dpi=300) - else: - plt.show() - return fig - - - @staticmethod def __create_cc_viewmat() -> "Matrix44": """Create the view matrix for a nice view of the corpus callosum. 
@@ -932,22 +661,19 @@ def snap_cc_picture( if fssurf_file: fssurf_file = Path(fssurf_file) else: - fssurf_file = tempfile.NamedTemporaryFile(suffix=".fssurf", delete=True) - self.write_fssurf(fssurf_file.name) + fssurf_file = tempfile.NamedTemporaryFile(suffix=".fssurf", delete=True).name + self.write_fssurf(fssurf_file) if overlay_file: - overlay_path: str | None = Path(overlay_file).name - elif hasattr(self, "mesh_vertex_colors"): - overlay_file = tempfile.NamedTemporaryFile(suffix=".w", delete=True) - # Write thickness values in FreeSurfer .w format - nib.freesurfer.write_morph_data(overlay_file.name, self.mesh_vertex_colors) - overlay_path = overlay_file.name + overlay_file: str | None = Path(overlay_file) else: - overlay_path = None + overlay_file = tempfile.NamedTemporaryFile(suffix=".w", delete=True).name + # Write thickness values in FreeSurfer '*.w' overlay format + self.write_morph_data(overlay_file) snap1( - fssurf_file.name, - overlaypath=overlay_path, + fssurf_file, + overlaypath=overlay_file, view=None, viewmat=self.__create_cc_viewmat(), width=3 * 500, diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index 0c6fa035..fbcf504e 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -265,13 +265,9 @@ def recon_cc_surf_measures_multi( logger.info("Saving template files (contours.txt, thickness_values.txt, " f"thickness_measurement_points.txt) to {template_dir}") for j in range(len(cc_contours)): - io_futures.extend([ - thread_executor().submit(cc_contours[j].save_contour, template_dir / f"contour_{j}.txt"), - thread_executor().submit(cc_contours[j].save_thickness_values, - template_dir / f"thickness_values_{j}.txt"), - thread_executor().submit(cc_contours[j].save_thickness_measurement_points, - template_dir / f"thickness_measurement_points_{j}.txt"), - ]) + # NOTE: this does not seem to be thread-safe, do not parallelize! + cc_contours[j].save_contour(template_dir / f"contour_{j}.txt") + cc_contours[j].save_thickness_values(template_dir / f"thickness_values_{j}.txt") mesh_outputs = ("html", "mesh", "thickness_overlay", "surf", "thickness_image") if len(cc_contours) > 1 and any(subject_dir.has_attribute(f"cc_{n}") for n in mesh_outputs): diff --git a/doc/scripts/cc_visualization.rst b/doc/scripts/cc_visualization.rst index 068a5a2c..b280a953 100644 --- a/doc/scripts/cc_visualization.rst +++ b/doc/scripts/cc_visualization.rst @@ -12,14 +12,13 @@ Usage Examples 3D Visualization ~~~~~~~~~~~~~~~~ -To visualize a 3D template generated by ``fastsurfer_cc.py`` (using ``--slice_selection all --save_template ...``): +To visualize a 3D template generated by ``fastsurfer_cc.py`` (using ``--slice_selection all --save_template ...``), +point the script to the exported template directory: .. code-block:: bash python3 cc_visualization.py \ - --contours /data/templates/sub001/contours.txt \ - --thickness /data/templates/sub001/thickness_values.txt \ - --measurement_points /data/templates/sub001/measurement_points.txt \ + --template_dir /data/templates/sub001/cc_template \ --output_dir /data/visualizations/sub001 2D Visualization @@ -30,11 +29,17 @@ To visualize a 2D template (using ``--slice_selection middle --save_template ... .. 
code-block:: bash

     python3 cc_visualization.py \
-        --thickness /data/templates/sub001/thickness_values.txt \
-        --measurement_points /data/templates/sub001/measurement_points.txt \
+        --template_dir /data/templates/sub001/cc_template \
         --output_dir /data/visualizations/sub001 \
         --twoD

+.. note::
+
+   You can still pass ``--contours``, ``--thickness`` and
+   ``--measurement_points`` directly when working with standalone files, but
+   ``--template_dir`` is the recommended way to load the multi-slice templates
+   produced by ``fastsurfer_cc.py``.
+
 Outputs
 -------

diff --git a/doc/scripts/contour.rst b/doc/scripts/contour.rst
new file mode 100644
index 00000000..faf75aec
--- /dev/null
+++ b/doc/scripts/contour.rst
@@ -0,0 +1,36 @@
+CorpusCallosum: contour.py
+==========================
+
+This module provides the ``CCContour`` class for reading, writing, and
+manipulating 2D corpus callosum contours together with per-vertex thickness
+values. Typical template outputs (from ``fastsurfer_cc.py --save_template``)
+emit one set per slice:
+
+- ``contour_<i>.txt``: CSV with header ``New contour, anterior_endpoint_idx=<a>, posterior_endpoint_idx=<p>`` followed by ``x,y`` rows.
+- ``thickness_values_<i>.txt``: CSV with header ``thickness`` and one value per contour vertex.
+- ``thickness_measurement_points_<i>.txt``: CSV with header ``vertex_idx`` listing the vertices where thickness was measured.
+
+Key usage patterns
+------------------
+
+.. code-block:: python
+
+    from CorpusCallosum.shape.contour import CCContour
+
+    contour = CCContour(contour_points, thickness_values,
+                        endpoint_idxs=(anterior_idx, posterior_idx),
+                        resolution=1.0)
+    contour.fill_thickness_values()  # interpolate missing values
+    contour.smooth_contour(window_size=5)
+    contour.save_contour("contour_0.txt")
+    contour.save_thickness_values("thickness_values_0.txt")
+    contour.save_thickness_measurement_points("thickness_measurement_points_0.txt")
+
+Reference
+---------
+
+.. automodule:: CorpusCallosum.shape.contour
+    :members: CCContour
+    :undoc-members:
+    :show-inheritance:
+

From 4f8a5a592a7bf55952e0784b34447fd969b78333 Mon Sep 17 00:00:00 2001
From: ClePol
Date: Wed, 10 Dec 2025 18:47:27 +0100
Subject: [PATCH 47/68] cleaned up visualization script logic, removed unused
 CC contour code, multiple bugfixes

---
 CorpusCallosum/cc_visualization.py     |  50 ++++--
 CorpusCallosum/shape/contour.py        | 209 ++++++-------------------
 CorpusCallosum/shape/mesh.py           |  41 ++---
 CorpusCallosum/shape/postprocessing.py |   6 +-
 4 files changed, 118 insertions(+), 188 deletions(-)

diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py
index d0b94d2b..b3d6ffad 100644
--- a/CorpusCallosum/cc_visualization.py
+++ b/CorpusCallosum/cc_visualization.py
@@ -10,6 +10,10 @@
 from CorpusCallosum.data.read_write import load_fsaverage_data
 from CorpusCallosum.shape.contour import CCContour
 from CorpusCallosum.shape.mesh import create_CC_mesh_from_contours
+from FastSurferCNN.utils import logging
+from FastSurferCNN.utils.logging import get_logger, setup_logging
+
+logger = get_logger(__name__)


 def make_parser() -> argparse.ArgumentParser:
@@ -79,10 +83,17 @@ def make_parser() -> argparse.ArgumentParser:
         help="Legend for the colorbar.",
         metavar="LEGEND")
     parser.add_argument(
-        "--twoD", 
-        action="store_true", 
+        "--twoD",
+        action="store_true",
         help="Generate 2D visualization instead of 3D mesh.",
     )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        action="count",
+        default=0,
+        help="Enable verbose (pass twice for debug-output).",
+    )
     return parser
@@ -171,12 +182,26 @@ def main(

     # 2D visualization
     mid_contour = contours[len(contours) // 2]
+
+    # for now, we only support thickness visualization; this prepares for plotting p-values and ICC values as well
+    mode = "thickness"
+    logger.info(f"Writing output to {output_dir / 'cc_thickness_2d.png'}")
+
+    if mode == "thickness":
+        raw_thickness_values = mid_contour.thickness_values[~np.isnan(mid_contour.thickness_values)]
+        # values are duplicated because there are two measurement points per levelpath
+        raw_thickness_values = raw_thickness_values[len(raw_thickness_values) // 2:]
+        mid_contour.plot_contour_colorfill(
+            plot_values=raw_thickness_values,
+            title=None,
+            save_path=str(output_dir / "cc_thickness_2d.png"),
+            colorbar=True,
+            mode=mode
+        )
     if twoD:
-        mid_contour.plot_cc_contour_with_levelsets(
-            title=None,
-            save_path=str(output_dir / "cc_thickness_2d.png"),
-            colorbar=True,
-        )
         return 0

     # 3D visualization
@@ -191,21 +216,28 @@
     cc_mesh.plot_mesh(**plot_kwargs)
     cc_mesh.plot_mesh(output_path=str(output_dir / "cc_mesh.html"), **plot_kwargs)

-    mid_contour.plot_cc_contour_with_levelsets(save_path=str(output_dir /
"midslice_2d.png")) cc_mesh.to_fs_coordinates(vox2ras_tkr=vox2ras_tkr) + logger.info(f"Writing vtk file to {output_dir / 'cc_mesh.vtk'}") cc_mesh.write_vtk(str(output_dir / "cc_mesh.vtk")) + logger.info(f"Writing freesurfer surface file to {output_dir / 'cc_mesh.fssurf'}") cc_mesh.write_fssurf(str(output_dir / "cc_mesh.fssurf")) + logger.info(f"Writing freesurfer overlay file to {output_dir / 'cc_mesh_overlay.curv'}") cc_mesh.write_morph_data(str(output_dir / "cc_mesh_overlay.curv")) try: cc_mesh.snap_cc_picture(str(output_dir / "cc_mesh_snap.png")) + logger.info(f"Writing 3D snapshot image to {output_dir / 'cc_mesh_snap.png'}") except RuntimeError: - return ("The cc_visualization script requires whippersnappy>=1.3.1 to makes screenshots, install with " + logger.warning("The cc_visualization script requires whippersnappy>=1.3.1 to makes screenshots, install with " "`pip install whippersnappy>=1.3.1` !") return 0 if __name__ == "__main__": options = options_parse() + + # Set up logging if verbose mode is enabled + logging.setup_logging(None, options.verbose) # Log to stdout only + sys.exit(main( template_dir=options.template_dir, output_dir=options.output_dir, diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py index aad028b5..1189311a 100644 --- a/CorpusCallosum/shape/contour.py +++ b/CorpusCallosum/shape/contour.py @@ -26,7 +26,6 @@ import FastSurferCNN.utils.logging as logging from CorpusCallosum.shape.endpoint_heuristic import smooth_contour from CorpusCallosum.shape.thickness import cc_thickness, make_mesh_from_contour -from FastSurferCNN.utils.common import suppress_stdout logger = logging.get_logger(__name__) @@ -43,7 +42,6 @@ class CCContour: endpoint_idxs : tuple[int, int] Tuple containing start and end indices for the contour. """ - def __init__( self, @@ -68,7 +66,8 @@ def __init__( raise ValueError(f"Contour must be a 2D array, but is {self.contour.shape}") self.thickness_values = thickness_values if self.contour.shape[0] != len(thickness_values): - raise ValueError(f"Number of contour points ({self.contour.shape[0]}) does not match number of thickness values ({len(thickness_values)})") + raise ValueError(f"Number of contour points ({self.contour.shape[0]}) does not match number \ + of thickness values ({len(thickness_values)})") # write vertex indices where thickness values are not nan self.original_thickness_vertices = np.where(~np.isnan(thickness_values))[0] self.resolution = resolution @@ -78,6 +77,7 @@ def __init__( else: self.endpoint_idxs = endpoint_idxs + def smooth_contour(self, window_size: int = 5) -> None: """Smooth a contour using a moving average filter. @@ -99,6 +99,7 @@ def smooth_contour(self, window_size: int = 5) -> None: x, y = smooth_contour(x, y, window_size) self.contour = np.array([x, y]).T + def copy(self) -> "CCContour": """Copy the contour. 
""" @@ -131,7 +132,6 @@ def create_levelpaths(self, num_points: int, update_data: bool = True ) -> tuple[list[np.ndarray], list[float]]: - midline_len, thickness, curvature, midline_equi, \ levelpaths, contour_with_thickness, endpoint_idxs = cc_thickness( self.contour, @@ -142,11 +142,12 @@ def create_levelpaths(self, if update_data: self.contour = contour_with_thickness[:, :2] self.thickness_values = contour_with_thickness[:,2] - self.original_thickness_vertices = np.where(~np.isnan(thickness))[0] + self.original_thickness_vertices = np.where(~np.isnan(self.thickness_values))[0] self.endpoint_idxs = endpoint_idxs return levelpaths, thickness + def set_thickness_values(self, thickness_values: np.ndarray, use_measurement_points: bool = False) -> None: """Set the thickness values for the contour. This is useful to update the thickness values for specific plots. @@ -159,106 +160,17 @@ def set_thickness_values(self, thickness_values: np.ndarray, use_measurement_poi Whether to use the measurement points to set the thickness values, by default False. """ if use_measurement_points: - assert len(thickness_values) == len(self.original_thickness_vertices), "Number of thickness values does not match number of measurement points" - self.thickness_values = np.full(len(self.contour), np.nan) - self.thickness_values[self.original_thickness_vertices] = thickness_values + if len(thickness_values) == len(self.original_thickness_vertices): + self.thickness_values = np.full(len(self.contour), np.nan) + self.thickness_values[self.original_thickness_vertices] = thickness_values + else: + raise ValueError("Number of thickness values " + f"does not match number of measurement points {len(self.original_thickness_vertices)}.") else: - assert len(thickness_values) == len(self.contour), "Number of thickness values does not match number of points in the contour" + assert len(thickness_values) == len(self.contour), "Number of thickness values does not match number of " \ + f"points in the contour {len(self.contour)}." self.thickness_values = thickness_values - def _create_levelpaths( - self, - points: np.ndarray, - trias: np.ndarray, - num_points: int | None = None - ) -> tuple[list[np.ndarray], list[float]]: - """Create level paths for thickness measurements. - - Parameters - ---------- - contour_idx : int - Index of the contour to process - points : np.ndarray - Array of shape (N, 2) containing mesh points - trias : np.ndarray - Array of shape (M, 3) containing triangle indices - num_points : int or None, optional - Number of points to sample along the midline, by default None - - Returns - ------- - tuple[list[np.ndarray], list[float]] - - levelpaths : List of arrays containing level path coordinates - - thickness_values : List of thickness values for each level path - - Notes - ----- - The function: - 1. Creates a triangular mesh from the points - 2. Finds boundary points and endpoints - 3. Solves Poisson equation for level sets - 4. 
Extracts level paths and interpolates thickness values - """ - - with suppress_stdout(): - cc_tria = lapy.TriaMesh(points, trias) - # extract boundary curve - bdr = np.array(cc_tria.boundary_loops()[0]) - - # find index of endpoints in bdr list - iidx1 = np.where(bdr == self.endpoint_idxs[0])[0][0] - iidx2 = np.where(bdr == self.endpoint_idxs[1])[0][0] - - # create boundary condition (0 at endpoints, -1 on one side, 1 on the other): - if iidx1 > iidx2: - tmp = iidx2 - iidx2 = iidx1 - iidx1 = tmp - dcond = np.ones(bdr.shape) - dcond[iidx1] = 0 - dcond[iidx2] = 0 - dcond[iidx1 + 1 : iidx2] = -1 - - # Extract path - with suppress_stdout(): - fem = lapy.Solver(cc_tria) - vfunc = fem.poisson(0, (bdr, dcond)) - if num_points is not None: - # TODO: do midline stuff - level = 0 - midline_equidistant, midline_length = cc_tria.level_path(vfunc, level, n_points=num_points + 2) - midline_equidistant = midline_equidistant[:, :2] - eval_points = midline_equidistant - else: - eval_points = self.contour - gf = lapy.diffgeo.compute_rotated_f(cc_tria, vfunc) - - # interpolate midline to get levels to evaluate - gf_interp = scipy.interpolate.griddata(cc_tria.v[:, 0:2], gf, eval_points, method="nearest") - - # sort by value - sorting_idx_gf = np.argsort(gf_interp) - gf_interp = gf_interp[sorting_idx_gf] - sorted_thickness_values = self.thickness_values[sorting_idx_gf] - - # get levels to evaluate - # level_length = tria.level_length(gf, gf_interp) - - levelpaths = [] - thickness_values = [] - - for i in range(0, len(eval_points)): - level = gf_interp[i] - # levelpath starts at index zero - if level == 0: - continue - lvlpath, lvlpath_length, tria_idx = cc_tria.level_path(gf, level, get_tria_idx=True) - - levelpaths.append(lvlpath) - thickness_values.append(sorted_thickness_values[i]) - - return levelpaths, thickness_values - def fill_thickness_values(self) -> None: """Interpolate missing thickness values using weighted averaging. @@ -319,7 +231,6 @@ def fill_thickness_values(self) -> None: self.thickness_values = thickness - def smooth_thickness_values(self, iterations: int = 1) -> None: """Smooth the thickness values using a Gaussian filter. @@ -339,7 +250,7 @@ def smooth_thickness_values(self, iterations: int = 1) -> None: self.thickness_values[i] = gaussian_filter1d(self.thickness_values[i], sigma=5) - def plot_contour(self, output_path: str) -> None: + def plot_contour(self, output_path: str | None = None) -> None: """Plot a single contour with thickness values. Parameters @@ -355,7 +266,8 @@ def plot_contour(self, output_path: str) -> None: - Connected contour line. - Grid, labels, and legend. """ - self.__make_parent_folder(output_path) + if output_path is not None: + self.__make_parent_folder(output_path) contour = self.contour @@ -386,12 +298,15 @@ def plot_contour(self, output_path: str) -> None: plt.legend() plt.grid(True) plt.tight_layout() - plt.savefig(output_path, dpi=300) - + if output_path is not None: + plt.savefig(output_path, dpi=300) + else: + plt.show() - def plot_cc_contour_with_levelsets( + def plot_contour_colorfill( self, + plot_values: np.ndarray, title: str | None = None, save_path: str | None = None, colorbar: bool = True, @@ -404,17 +319,15 @@ def plot_cc_contour_with_levelsets( Parameters ---------- - contour_idx : int, default=0 - Index of the contour to plot, by default 0. - levelpaths : list, optional - List of levelset paths. If None, uses stored levelpaths. + plot_values : np.ndarray + Array of values to plot on CC from anterior to posterior (left to right in the plot). 
title : str, optional Title for the plot. save_path : str, optional Path to save the plot. If None, displays interactively. colorbar : bool, default=True Whether to show the colorbar. - mode : {"p-value", "icc"}, default="p-value" + mode : {"p-value", "icc", "thickness"}, default="p-value" Mode of the plot. Returns @@ -422,15 +335,14 @@ def plot_cc_contour_with_levelsets( matplotlib.figure.Figure The created figure object. """ + plot_values = plot_values[::-1] # make sure values are plotted left to right (anterior to posterior) - plot_values = np.array(self.thickness_values[~np.isnan(self.thickness_values)])[::-1] - points, trias = make_mesh_from_contour(self.contour, max_volume=0.5, min_angle=25, verbose=False) + points, _ = make_mesh_from_contour(self.contour, max_volume=0.5, min_angle=25, verbose=False) # make points 3D by adding zero points = np.column_stack([points, np.zeros(len(points))]) levelpaths, _ = self.create_levelpaths(num_points=len(plot_values)-1, update_data=False) - #levelpaths, _ = self._create_levelpaths(points, trias, num_points=len(plot_values)-2) outside_contour = self.contour.T @@ -519,14 +431,17 @@ def plot_cc_contour_with_levelsets( elif mode == "icc": colors1 = plt.cm.Blues(np.linspace(0, 1, 128)) colors2 = plt.cm.binary([0.4] * 128) + elif mode == "thickness": + # Blue to red colormap for thickness values + cmap = plt.cm.coolwarm else: raise ValueError(f"Invalid mode '{mode}'") - # Combine the color samples - colors = np.vstack((colors2, colors1)) - - # Create a new colormap - cmap = matplotlib.colors.LinearSegmentedColormap.from_list("my_colormap", colors) + # Combine the color samples for p-value and icc modes + if mode != "thickness": + colors = np.vstack((colors2, colors1)) + # Create a new colormap + cmap = matplotlib.colors.LinearSegmentedColormap.from_list("my_colormap", colors) # Plot CC contour with levelsets fig = plt.figure(figsize=(10, 3)) @@ -543,8 +458,8 @@ def plot_cc_contour_with_levelsets( cmap=cmap, alpha=1, interpolation="bilinear", - vmin=0, - vmax=0.10 if mode == "p-value" else 1, + vmin=0 if mode != "thickness" else np.nanmin(plot_values), + vmax=0.10 if mode == "p-value" else (1 if mode == "icc" else np.nanmax(plot_values)), transform=transform, ) @@ -555,15 +470,15 @@ def plot_cc_contour_with_levelsets( cmap=cmap, alpha=1, interpolation="bilinear", - vmin=0, - vmax=0.10 if mode == "p-value" else 1, + vmin=0 if mode != "thickness" else np.nanmin(plot_values), + vmax=0.10 if mode == "p-value" else (1 if mode == "icc" else np.nanmax(plot_values)), # norm=LogNorm(vmin=1e-3, vmax=0.1), # Set minimum to avoid log(0) transform=transform, ) if colorbar: # Add a colorbar - cbar = plt.colorbar(aspect=10) + cbar = plt.colorbar(aspect=15) if mode == "p-value": cbar.ax.set_ylim(0.001, 0.054) cbar.ax.set_yticks([0.0, 0.01, 0.02, 0.03, 0.04, 0.05]) @@ -572,6 +487,12 @@ def plot_cc_contour_with_levelsets( cbar.ax.set_ylim(0, 1) cbar.ax.set_yticks([0, 0.25, 0.5, 0.75, 1]) cbar.ax.set_label("Intraclass correlation coefficient") + elif mode == "thickness": + # Set limits based on actual thickness values + thickness_min = np.nanmin(plot_values) + thickness_max = np.nanmax(plot_values) + cbar.ax.set_ylim(thickness_min, thickness_max) + cbar.set_label("Thickness (mm)") # Plot the outside contour on top for clear boundary plt.plot(outside_contour[0], outside_contour[1], "k-", linewidth=2, label="CC Contour", transform=transform) @@ -588,6 +509,7 @@ def plot_cc_contour_with_levelsets( plt.show() return fig + @staticmethod def __make_parent_folder(filename: Path 
| str) -> None: """Create the parent folder for a file if it doesn't exist. @@ -604,39 +526,7 @@ def __make_parent_folder(filename: Path | str) -> None: """ Path(filename).parent.mkdir(parents=False, exist_ok=True) - @staticmethod - def _load_thickness_measurement_points(filename: str) -> np.ndarray: - """Load thickness measurement points from a CSV file. - - Parameters - ---------- - filename : str - Path to the CSV file containing measurement points. - - Returns - ------- - np.ndarray - Array containing vertex indices where thickness was measured. - - Notes - ----- - The function: - 1. Reads CSV file with format: slice_idx,vertex_idx (legacy) or a single - vertex_idx column (current output format). - 2. Returns a flat array of vertex indices. - """ - data = np.loadtxt(filename, delimiter=",", skiprows=1) - # handle scalar files (single measurement point) - if data.ndim == 0: - return np.array([int(data)]) - if data.ndim == 1: - return data.astype(int) - if data.shape[1] != 1: - raise ValueError("Thickness measurement points file must contain a single vertex_idx column.") - return data[:, 0].astype(int) - - def save_contour(self, output_path: Path | str) -> None: """Save the contours to a CSV file. @@ -664,6 +554,7 @@ def save_contour(self, output_path: Path | str) -> None: for point in self.contour: f.write(f"{point[0]},{point[1]}\n") + def load_contour(self, input_path: str) -> None: """Load contour from a CSV file. @@ -708,6 +599,7 @@ def load_contour(self, input_path: str) -> None: current_points.append([float(x), float(y)]) self.contour = np.array(current_points) + def save_thickness_values(self, output_path: Path | str) -> None: """Save thickness values to a CSV file. @@ -730,6 +622,7 @@ def save_thickness_values(self, output_path: Path | str) -> None: for value in self.thickness_values: f.write(f"{value}\n") + def load_thickness_values( self, input_path: str, diff --git a/CorpusCallosum/shape/mesh.py b/CorpusCallosum/shape/mesh.py index 0a4ec187..09084bc5 100644 --- a/CorpusCallosum/shape/mesh.py +++ b/CorpusCallosum/shape/mesh.py @@ -26,6 +26,7 @@ from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE from CorpusCallosum.shape.contour import CCContour from CorpusCallosum.shape.thickness import make_mesh_from_contour +from FastSurferCNN.utils.common import suppress_stdout try: from pyrr import Matrix44 @@ -670,26 +671,28 @@ def snap_cc_picture( overlay_file = tempfile.NamedTemporaryFile(suffix=".w", delete=True).name # Write thickness values in FreeSurfer '*.w' overlay format self.write_morph_data(overlay_file) + - snap1( - fssurf_file, - overlaypath=overlay_file, - view=None, - viewmat=self.__create_cc_viewmat(), - width=3 * 500, - height=3 * 300, - outpath=output_path, - ambient=0.6, - colorbar_scale=0.5, - colorbar_y=0.88, - colorbar_x=0.19, - brain_scale=2.1, - fthresh=0, - caption="Corpus Callosum thickness (mm)", - caption_y=0.85, - caption_x=0.17, - caption_scale=0.5, - ) + with suppress_stdout(): + snap1( + fssurf_file, + overlaypath=overlay_file, + view=None, + viewmat=self.__create_cc_viewmat(), + width=3 * 500, + height=3 * 300, + outpath=output_path, + ambient=0.6, + colorbar_scale=0.5, + colorbar_y=0.88, + colorbar_x=0.19, + brain_scale=2.1, + fthresh=0, + caption="Corpus Callosum thickness (mm)", + caption_y=0.85, + caption_x=0.17, + caption_scale=0.5, + ) if fssurf_file and hasattr(fssurf_file, "close"): fssurf_file.close() diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index fbcf504e..c899b10e 100644 --- 
a/CorpusCallosum/shape/postprocessing.py
+++ b/CorpusCallosum/shape/postprocessing.py
@@ -21,7 +21,7 @@
 from CorpusCallosum.data.constants import CC_LABEL, FSAVERAGE_MIDDLE, SUBSEGMENT_LABELS
 from CorpusCallosum.shape.contour import CCContour
 from CorpusCallosum.shape.endpoint_heuristic import get_endpoints
-from CorpusCallosum.shape.mesh import CCMesh, create_CC_mesh_from_contours
+from CorpusCallosum.shape.mesh import create_CC_mesh_from_contours
 from CorpusCallosum.shape.metrics import calculate_cc_index
 from CorpusCallosum.shape.subsegment_contour import (
     get_primary_eigenvector,
@@ -221,7 +221,9 @@ def recon_cc_surf_measures_multi(
             logger.info(f"Calculating CC measurements for slice {slice_idx+1}{progress}")
         cc_measures, contour_in_as_space_and_thickness, endpoint_idxs = _results
         contour_in_as_space, thickness_values = np.split(contour_in_as_space_and_thickness, (2,), axis=1)
-        cc_contours.append(CCContour(contour_in_as_space, thickness_values[:, 0], endpoint_idxs, resolution=vox_size[0]))
+        cc_contours.append(
+            CCContour(contour_in_as_space, thickness_values[:, 0], endpoint_idxs, resolution=vox_size[0])
+        )
         if cc_measures is None:
             # this should not happen, but just in case
             logger.warning(f"Slice index {slice_idx+1}{progress} returned result `None`")

From 2c924c89926b4efc1c845c79e3e0cb6fd1f01d08 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20K=C3=BCgler?=
Date: Wed, 10 Dec 2025 10:31:12 +0100
Subject: [PATCH 48/68] Add the types file to merge all recurring types of
 FastSurferCC

Cleanup of
- typing
- docstrings
- function names
- variable names
---
 CorpusCallosum/cc_visualization.py         |   4 +-
 CorpusCallosum/data/read_write.py          |   5 +-
 CorpusCallosum/fastsurfer_cc.py            | 104 ++++++------
 CorpusCallosum/localization/inference.py   |  45 +++---
 CorpusCallosum/paint_cc_into_pred.py       |   6 +-
 CorpusCallosum/segmentation/inference.py   |  68 ++++----
 .../segmentation_postprocessing.py         |  22 +--
 CorpusCallosum/shape/contour.py            |  20 +--
 CorpusCallosum/shape/mesh.py               |  15 +-
 CorpusCallosum/shape/postprocessing.py     | 142 +++++++++------
 CorpusCallosum/shape/subsegment_contour.py |  91 ++++++-----
 CorpusCallosum/shape/thickness.py          |  23 ++-
 CorpusCallosum/utils/mapping_helpers.py    | 152 +++++++-----------
 CorpusCallosum/utils/types.py              |  74 +++++++++
 14 files changed, 421 insertions(+), 350 deletions(-)
 create mode 100644 CorpusCallosum/utils/types.py

diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py
index b3d6ffad..3ce992fa 100644
--- a/CorpusCallosum/cc_visualization.py
+++ b/CorpusCallosum/cc_visualization.py
@@ -11,7 +11,7 @@
 from CorpusCallosum.shape.contour import CCContour
 from CorpusCallosum.shape.mesh import create_CC_mesh_from_contours
 from FastSurferCNN.utils import logging
-from FastSurferCNN.utils.logging import get_logger, setup_logging
+from FastSurferCNN.utils.logging import get_logger

 logger = get_logger(__name__)

@@ -141,7 +141,7 @@ def load_contours_from_template_dir(
     num_thickness_values = np.sum(~np.isnan(np.array(thickness_values[1:], dtype=float)))
     if fsaverage_contour is None:
         fsaverage_contour = load_fsaverage_cc_template()
-    # create measurment points (points = 2 x levelpaths) accorindg to number of thickness values
+    # create measurement points (points = 2 x levelpaths) according to number of thickness values
     fsaverage_contour.create_levelpaths(num_points=num_thickness_values // 2, update_data=True)
     current_contour = fsaverage_contour.copy()
     current_contour.load_thickness_values(thickness_file)
diff --git a/CorpusCallosum/data/read_write.py
b/CorpusCallosum/data/read_write.py
index a0621263..e11b8632 100644
--- a/CorpusCallosum/data/read_write.py
+++ b/CorpusCallosum/data/read_write.py
@@ -16,12 +16,11 @@
 from pathlib import Path
 from typing import TypedDict

-import nibabel as nib
 import numpy as np
 from numpy import typing as npt

 import FastSurferCNN.utils.logging as logging
-from FastSurferCNN.utils import AffineMatrix4x4
+from FastSurferCNN.utils import AffineMatrix4x4, nibabelImage
 from FastSurferCNN.utils.parallel import thread_executor

@@ -34,7 +33,7 @@ class FSAverageHeader(TypedDict):
 logger = logging.get_logger(__name__)

-def calc_ras_centroids_from_seg(seg_img: nib.analyze.SpatialImage, label_ids: list[int] | None = None) \
+def calc_ras_centroids_from_seg(seg_img: nibabelImage, label_ids: list[int] | None = None) \
         -> dict[int, np.ndarray | None]:
     """Get centroids of segmentation labels in RAS coordinates, accepts any affine/data layout.

diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py
index 8cfb2832..ac343ad9 100644
--- a/CorpusCallosum/fastsurfer_cc.py
+++ b/CorpusCallosum/fastsurfer_cc.py
@@ -25,7 +25,6 @@
 import numpy as np
 import torch
 from monai.networks.nets import DenseNet
-from numpy import typing as npt
 from scipy.ndimage import affine_transform

 from CorpusCallosum.data.constants import (
@@ -48,6 +47,7 @@
 from CorpusCallosum.segmentation import inference as segmentation_inference
 from CorpusCallosum.segmentation import segmentation_postprocessing
 from CorpusCallosum.shape.postprocessing import (
+    CCMeasuresDict,
     SliceSelection,
     SubdivisionMethod,
     check_area_changes,
@@ -63,7 +63,7 @@
 )
 from FastSurferCNN.data_loader.conform import conform, is_conform
 from FastSurferCNN.segstats import HelpFormatter
-from FastSurferCNN.utils import AffineMatrix4x4, logging
+from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Mask3d, Shape3d, Vector2d, logging, nibabelImage
 from FastSurferCNN.utils.arg_types import path_or_none
 from FastSurferCNN.utils.common import SubjectDirectory, find_device
 from FastSurferCNN.utils.lta import write_lta
@@ -72,6 +72,7 @@
 from recon_surf.align_points import find_rigid

 logger = logging.get_logger(__name__)
+
 _TPathLike = TypeVar("_TPathLike", str, Path, Literal[None])

@@ -164,8 +165,8 @@ def _set_help_sid(action):
         "cost of precision.",
     )
     def _slice_selection(a: str) -> SliceSelection:
-        if a.lower() in ("middle", "all"):
-            return a.lower()
+        if (b := a.lower()) in ("middle", "all"):
+            return b
         return int(a)
     parser.add_argument(
         "--slice_selection",
@@ -366,7 +367,7 @@ def options_parse() -> argparse.Namespace:
     return args

-def register_centroids_to_fsavg(aseg_nib: nib.analyze.SpatialImage) \
+def register_centroids_to_fsavg(aseg_nib: nibabelImage) \
         -> tuple[AffineMatrix4x4, AffineMatrix4x4, AffineMatrix4x4, FSAverageHeader, AffineMatrix4x4]:
     """Perform centroid-based registration between subject and fsaverage space.

@@ -432,12 +433,12 @@

 def localize_ac_pc(
-    orig_data: np.ndarray,
-    aseg_nib: nib.analyze.SpatialImage,
+    orig_data: Image3d,
+    aseg_nib: nibabelImage,
     orig2midslice_vox2vox: AffineMatrix4x4,
     model_localization: DenseNet,
-    resample_shape: tuple[int, int, int],
-) -> tuple[npt.NDArray[float], npt.NDArray[float]]:
+    resample_shape: Shape3d,
+) -> tuple[Vector2d, Vector2d]:
     """Localize anterior and posterior commissure points in the brain.
Uses a trained model to detect AC and PC points in mid-sagittal slices, @@ -447,7 +448,7 @@ def localize_ac_pc( ---------- orig_data : np.ndarray Array of intensity data. - aseg_nib : nibabel.analyze.SpatialImage + aseg_nib : nibabelImage Subject's segmentation image in native subject space. orig2midslice_vox2vox : np.ndarray Transformation matrix from subject/native space to fsaverage space (in lia). @@ -490,12 +491,12 @@ def localize_ac_pc( def segment_cc( - midslices: np.ndarray, - ac_coords: npt.NDArray[float], - pc_coords: npt.NDArray[float], - aseg_nib: "nib.Nifti1Image", + midslices: Image3d, + ac_coords: Vector2d, + pc_coords: Vector2d, + aseg_nib: nibabelImage, model_segmentation: "torch.nn.Module", -) -> tuple[npt.NDArray[bool], npt.NDArray[float]]: +) -> tuple[Mask3d, Image3d]: """Segment the corpus callosum using a trained model. Performs corpus callosum segmentation on mid-sagittal slices using a trained model, with AC-PC points as anatomical @@ -509,7 +510,7 @@ def segment_cc( Anterior commissure coordinates. pc_coords : np.ndarray Posterior commissure coordinates. - aseg_nib : nibabel.Nifti1Image + aseg_nib : nibabelImage Subject's cc_seg_labels image. model_segmentation : torch.nn.Module Trained model for CC cc_seg_labels. @@ -697,8 +698,16 @@ def main( #### setup variables io_futures = [] + # load models + device = find_device(device) + logger.info(f"Using device: {device}") + + logger.info("Loading models") + _model_localization = thread_executor().submit(localization_inference.load_model, device=device) + _model_segmentation = thread_executor().submit(segmentation_inference.load_model, device=device) + _aseg_fut = thread_executor().submit(nib.load, sd.filename_by_attribute("aseg_name")) - orig = cast(nib.analyze.SpatialImage, nib.load(sd.conf_name)) + orig = cast(nibabelImage, nib.load(sd.conf_name)) # check that the image is conformed, i.e. isotropic 1mm voxels, 256^3 size, LIA orientation if not is_conform(orig, vox_size=None, img_size=None, orientation=None): @@ -718,15 +727,7 @@ def main( "center around the mid-sagittal plane)" ) - # load models - device = find_device(device) - logger.info(f"Using device: {device}") - - logger.info("Loading models") - _model_localization = thread_executor().submit(localization_inference.load_model, device=device) - _model_segmentation = thread_executor().submit(segmentation_inference.load_model, device=device) - - aseg_img = cast(nib.analyze.SpatialImage, _aseg_fut.result()) + aseg_img = cast(nibabelImage, _aseg_fut.result()) if not np.allclose(aseg_img.affine, orig.affine): logger.error("Input MRI and segmentation are not aligned! 
Please check your input files.") @@ -765,7 +766,8 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4: #### do localization and segmentation inference logger.info("Starting AC/PC localization") - target_shape = (slices_to_analyze, fsavg_header["dims"][1], fsavg_header["dims"][2]) + target_shape: tuple[int, int, int] = (slices_to_analyze, fsavg_header["dims"][1], fsavg_header["dims"][2]) + # predict ac and pc coordinates in upright AS space ac_coords, pc_coords = localize_ac_pc( np.asarray(orig.dataobj), aseg_img, @@ -774,8 +776,9 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4: target_shape, ) logger.info("Starting corpus callosum segmentation") - target_shape = (slices_to_analyze + 8, fsavg_header["dims"][1], fsavg_header["dims"][2]) # 8 for context slices - midslices = affine_transform( + # "+ 8" in x-direction for context slices + target_shape: Shape3d = (slices_to_analyze + 8, fsavg_header["dims"][1], fsavg_header["dims"][2]) + midslices: Image3d = affine_transform( np.asarray(orig.dataobj), np.linalg.inv(_orig2midslab_vox2vox(extra_slices=8)), # inverse is required for affine_transform output_shape=target_shape, @@ -823,7 +826,7 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4: ) io_futures.extend(slice_io_futures) - outer_contours = [slice_result['split_contours'][0] for slice_result in slice_results] + outer_contours = [slice_result["split_contours"][0] for slice_result in slice_results] if len(outer_contours) > 1 and not check_area_changes(outer_contours): logger.warning( @@ -831,12 +834,12 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4: ) # Get middle slice result - middle_slice_result = slice_results[len(slice_results) // 2] - if len(middle_slice_result['split_contours']) <= 5: + middle_slice_result: CCMeasuresDict = slice_results[len(slice_results) // 2] + if len(middle_slice_result["split_contours"]) <= 5: cc_subseg_midslice = make_subdivision_mask( - cc_fn_seg_labels.shape[1:], - middle_slice_result['split_contours'], - orig.header.get_zooms(), + (cc_fn_seg_labels.shape[1], cc_fn_seg_labels.shape[2]), + middle_slice_result["split_contours"], + vox_size[1:], ) else: logger.warning("Too many subsegments for lookup table, skipping sub-division of output segmentation.") @@ -928,22 +931,20 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4: additional_metrics["contour_smoothing"] = contour_smoothing additional_metrics["slice_selection"] = slice_selection - # Convert numpy arrays to lists for JSON serialization - output_metrics_middle_slice = convert_numpy_to_json_serializable(output_metrics_middle_slice | additional_metrics) if sd.has_attribute("cc_mid_measures"): - logger.info(f"Saving CC markers to {sd.filename_by_attribute('cc_mid_measures')}") - sd.filename_by_attribute("cc_mid_measures").parent.mkdir(exist_ok=True, parents=True) - with open(sd.filename_by_attribute("cc_mid_measures"), "w") as f: - json.dump(output_metrics_middle_slice, f, indent=4) + io_futures.append(thread_executor().submit( + save_cc_measures_json, + sd.filename_by_attribute('cc_mid_measures'), + output_metrics_middle_slice | additional_metrics, + )) if sd.has_attribute("cc_measures"): - per_slice_output_dict = convert_numpy_to_json_serializable(per_slice_output_dict | additional_metrics) - sd.filename_by_attribute("cc_measures").parent.mkdir(exist_ok=True, parents=True) - # Save slice-wise postprocessing results to JSON - with open(sd.filename_by_attribute("cc_measures"), "w") as f: - json.dump(per_slice_output_dict, f, 
indent=4) - logger.info(f"Multiple slice post-processing results saved to {sd.filename_by_attribute('cc_measures')}") + io_futures.append(thread_executor().submit( + save_cc_measures_json, + sd.filename_by_attribute("cc_measures"), + per_slice_output_dict | additional_metrics, + )) # save lta to fsaverage space @@ -986,6 +987,15 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4: logger.info(f"CorpusCallosum analysis pipeline completed successfully in {duration:.2f} seconds.") +def save_cc_measures_json(cc_mid_measure_file: Path, metrics: dict[str, object]): + """Save JSON metrics file.""" + # Convert numpy arrays to lists for JSON serialization + logger.info(f"Saving CC markers to {cc_mid_measure_file}") + cc_mid_measure_file.parent.mkdir(exist_ok=True, parents=True) + with open(cc_mid_measure_file, "w") as f: + json.dump(convert_numpy_to_json_serializable(metrics), f, indent=4) + + if __name__ == "__main__": options = options_parse() diff --git a/CorpusCallosum/localization/inference.py b/CorpusCallosum/localization/inference.py index b9b2a7ed..15e99f37 100644 --- a/CorpusCallosum/localization/inference.py +++ b/CorpusCallosum/localization/inference.py @@ -13,17 +13,19 @@ # limitations under the License. from pathlib import Path +from typing import Literal import numpy as np import torch from monai import transforms from monai.networks.nets import DenseNet -from numpy import typing as npt from CorpusCallosum.transforms.localization import CropAroundACPCFixedSize from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML +from CorpusCallosum.utils.types import Points2dType from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults from FastSurferCNN.download_checkpoints import main as download_checkpoints +from FastSurferCNN.utils import Image3d, Vector2d, Vector3d from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT PATCH_SIZE = (64, 64) @@ -97,9 +99,9 @@ def get_transforms() -> transforms.Compose: def preprocess_volume( image_volume: np.ndarray, - center_pt: npt.NDArray[float], + center_pt: Vector3d, transform: transforms.Transform | None = None -) -> dict[str, torch.Tensor]: +) -> dict[str, torch.Tensor | tuple[int, ...]]: """Preprocess a volume for inference. Parameters @@ -114,16 +116,14 @@ def preprocess_volume( Returns ------- - dict[str, torch.Tensor] + dict[str, torch.Tensor | tuple[int, ...]] Dictionary containing preprocessed image tensor. """ if transform is None: transform = get_transforms() - # During training we used AC/PC coordinates, but during inference - # we approximate this by the center of the third ventricle. - # Therefore we put in the third ventricle center as dummy AC/PC coordinates - # for cropping the image. + # During training we used AC/PC coordinates, but during inference we approximate this by the center of the third + # ventricle. Therefore we put in the third ventricle center as dummy AC/PC coordinates for cropping the image. 
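+    # A hypothetical illustration (names below are not from this module): with
+    # `mid_slab` a conformed sagittal slab and `v3_centroid` the third-ventricle
+    # centroid in voxel coordinates,
+    #     sample = preprocess_volume(mid_slab, center_pt=v3_centroid)
+    # crops a fixed-size patch (PATCH_SIZE) around that surrogate AC/PC location.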
sample = {"image": image_volume[None], "AC_center": center_pt[1:][None], "PC_center": center_pt[1:][None]} # Apply transforms @@ -136,13 +136,13 @@ def preprocess_volume( return transformed -def run_inference( +def predict( model: torch.nn.Module, - image_volume: np.ndarray, + image_volume: Image3d, patch_center: np.ndarray, device: torch.device | None = None, transform: transforms.Transform | None = None - ) -> tuple[npt.NDArray[float], npt.NDArray[float], np.ndarray, tuple[int, int]]: + ) -> tuple[Points2dType, Points2dType, tuple[int, int]]: """ Run inference on an image volume @@ -165,9 +165,7 @@ def run_inference( Predicted PC coordinates. ac_coord : np.ndarray Predicted AC coordinates. - image : np.ndarray - Processed input images. - crop_offsets : tuple[int, int] + crop_offsets : pair of ints Crop offsets (left, top). """ if device is None: @@ -190,31 +188,32 @@ def run_inference( outputs = model(inputs) * torch.as_tensor([PATCH_SIZE + PATCH_SIZE], device=device) t_crops = [(t_dict['crop_left'] + t_dict['crop_top']) * 2] - outs: npt.NDArray[float] = outputs.cpu().numpy() + np.asarray(t_crops, dtype=float) - return outs[:, :2], outs[:, 2:], inputs.cpu().numpy(), (t_dict["crop_left"][0], t_dict["crop_top"][0]) + outs: np.ndarray[tuple[int, Literal[4]], np.dtype[float]] = outputs.cpu().numpy() + np.asarray(t_crops, dtype=float) + crop_offsets: tuple[int, int] = (t_dict["crop_left"][0], t_dict["crop_top"][0]) + return outs[:, :2], outs[:, 2:], crop_offsets def run_inference_on_slice( model: DenseNet, - image_slice: np.ndarray, - center_pt: np.ndarray, + image_slab: Image3d, + center_pt: Vector2d, num_iterations: int = 2, debug_output: str | None = None, -) -> tuple[npt.NDArray[float], npt.NDArray[float]]: +) -> tuple[Vector2d, Vector2d]: """Run inference on a single slice to detect AC and PC points. Parameters ---------- model : torch.nn.Module Trained model for AC-PC detection. - image_slice : np.ndarray + image_slab : np.ndarray 3D image mid-slices to run inference on in RAS. center_pt : np.ndarray Initial center point estimate for cropping. num_iterations : int, default=2 Number of refinement iterations to run. debug_output : str, optional - Path to save debug visualization, by default None. + Path to save debug visualization. 
Returns ------- @@ -231,7 +230,7 @@ def run_inference_on_slice( crop_left, crop_top = 0, 0 # Run inference for _ in range(num_iterations): - pc_coords, ac_coords, _, (crop_left, crop_top) = run_inference(model, image_slice, center_pt) + pc_coords, ac_coords, (crop_left, crop_top) = predict(model, image_slab, center_pt) center_pt = np.mean(np.stack([ac_coords, pc_coords], axis=0), axis=(0, 1)) # average ac and pc coords across sagittal slices _pc_coords = np.mean(pc_coords, axis=0) @@ -241,7 +240,7 @@ def run_inference_on_slice( import matplotlib.pyplot as plt from matplotlib.patches import Rectangle fig, ax = plt.subplots(1, 1, figsize=(10, 8)) - ax.imshow(image_slice[image_slice.shape[0]//2, :, :], cmap='gray') + ax.imshow(image_slab[image_slab.shape[0] // 2, :, :], cmap='gray') # Plot points on all views ax.scatter(pc_coords[:, 1], pc_coords[:, 0], c='r', marker='x', label='PC') ax.scatter(ac_coords[:, 1], ac_coords[:, 0], c='b', marker='x', label='AC') diff --git a/CorpusCallosum/paint_cc_into_pred.py b/CorpusCallosum/paint_cc_into_pred.py index f39a1c29..420abaea 100644 --- a/CorpusCallosum/paint_cc_into_pred.py +++ b/CorpusCallosum/paint_cc_into_pred.py @@ -263,15 +263,17 @@ def correct_wm_ventricles( if __name__ == "__main__": + from FastSurferCNN.utils import nibabelImage + # Command Line options are error checking done here options = argument_parse() logging.setup_logging() logger.info(f"Reading inputs: {options.input_cc} {options.input_pred}...") - cc_seg_image = cast(nib.analyze.SpatialImage, nib.load(options.input_cc)) + cc_seg_image = cast(nibabelImage, nib.load(options.input_cc)) cc_seg_data = np.asanyarray(cc_seg_image.dataobj) - aseg_image = cast(nib.analyze.SpatialImage, nib.load(options.input_pred)) + aseg_image = cast(nibabelImage, nib.load(options.input_pred)) aseg_data = np.asanyarray(aseg_image.dataobj) def _is_conform(img, dtype, verbose): diff --git a/CorpusCallosum/segmentation/inference.py b/CorpusCallosum/segmentation/inference.py index d3feaa46..242177cd 100644 --- a/CorpusCallosum/segmentation/inference.py +++ b/CorpusCallosum/segmentation/inference.py @@ -13,6 +13,7 @@ # limitations under the License. from collections.abc import Iterator from pathlib import Path +from typing import cast, overload import nibabel as nib import numpy as np @@ -26,6 +27,7 @@ from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults from FastSurferCNN.download_checkpoints import main as download_checkpoints from FastSurferCNN.models.networks import FastSurferVINN +from FastSurferCNN.utils import Image3d, Image4d, Shape2d, Shape3d, Shape4d, Vector2d, nibabelImage from FastSurferCNN.utils.parallel import thread_executor @@ -82,19 +84,19 @@ def load_model(device: torch.device | None = None) -> FastSurferVINN: def run_inference( - model: FastSurferVINN, - image_slice: np.ndarray, - ac_center: np.ndarray, - pc_center: np.ndarray, + model: "torch.nn.Module", + image_slice: Image3d, + ac_center: Vector2d, + pc_center: Vector2d, voxel_size: tuple[float, float], device: torch.device | None = None, transform: transforms.Transform | None = None -) -> tuple[npt.NDArray[int], npt.NDArray[float], npt.NDArray[float]]: +) -> tuple[np.ndarray[Shape4d, np.dtype[int]], Image4d, Image4d]: """Run inference on a single image slice. Parameters ---------- - model : FastSurferVINN + model : torch.nn.Module Trained model. image_slice : np.ndarray LIA-oriented input image as numpy array of shape (L, I, A). @@ -104,11 +106,10 @@ def run_inference( Posterior commissure coordinates. 
voxel_size : a pair of floats Voxel size of inferior/superior and anterior/posterior direction in mm. - device : torch.device or None, optional - Device to run inference on, by default None. - If None, uses the device of the model. - transform : transforms.Transform or None, optional - Custom transform pipeline, by default None. + device : torch.device, optional + Device to run inference on. If None, uses the device of the model. + transform : transforms.Transform, optional + Custom transform pipeline. Returns ------- @@ -152,11 +153,11 @@ def run_inference( labels = np.pad(_labels.cpu().numpy(), pad_tuples, mode='constant', constant_values=0) softlabels = np.pad(softlabels, pad_tuples, mode='constant', constant_values=0) - return [x.transpose(0, 2, 3, 1) for x in (labels, _inputs.cpu().numpy(), softlabels)] + return tuple(x.transpose(0, 2, 3, 1) for x in (labels, _inputs.cpu().numpy(), softlabels)) def load_validation_data( - path: str | Path + path: str | Path, ) -> tuple[npt.NDArray[str], npt.NDArray[float], npt.NDArray[float], Iterator[int], npt.NDArray[str], list[str]]: """Load validation data from CSV file and compute label widths. @@ -210,7 +211,7 @@ def _load(label_path: str | Path) -> int: int Number of slices containing non-zero labels, or total slices if <= 100 """ - label_img = nib.load(label_path) + label_img = cast(nibabelImage, nib.load(label_path)) if label_img.shape[0] > 100: # check which slices have non-zero values @@ -226,11 +227,16 @@ def _load(label_path: str | Path) -> int: return images, ac_centers, pc_centers, label_widths, labels, subj_ids +@overload +def one_hot_to_label(one_hot: Image4d, label_ids: list[int] | None = None) -> np.ndarray[Shape3d, np.dtype[int]]: ... + +@overload +def one_hot_to_label(one_hot: Image3d, label_ids: list[int] | None = None) -> np.ndarray[Shape2d, np.dtype[int]]: ... def one_hot_to_label( - one_hot: npt.NDArray[float], - label_ids: list[int] | None = None -) -> npt.NDArray[int]: + one_hot: np.ndarray[tuple[int, ...], np.dtype[bool]], + label_ids: list[int] | None = None, +) -> np.ndarray[tuple[int, ...], np.dtype[int]]: """Convert one-hot encoded segmentation to label map. Converts a one-hot encoded segmentation array to discrete labels by taking @@ -238,10 +244,10 @@ def one_hot_to_label( Parameters ---------- - one_hot : npt.NDArray[float] + one_hot : np.ndarray of floats One-hot encoded segmentation array of shape (..., num_classes). - label_ids : list[int] or None, optional - List of label IDs to map classes to. If None, defaults to [0, 192, 250]. + label_ids : array_like of ints, optional + List of label IDs to map classes to. If None, defaults to [0, FORNIX_LABEL, CC_LABEL]. The index in this list corresponds to the class index from argmax. Returns @@ -250,7 +256,8 @@ def one_hot_to_label( Label map with discrete integer labels. 
""" if label_ids is None: - label_ids = [0, 192, 250] + from CorpusCallosum.data.constants import CC_LABEL, FORNIX_LABEL + label_ids = [0, FORNIX_LABEL, CC_LABEL] label = np.argmax(one_hot, axis=3) if label_ids is not None: @@ -259,21 +266,20 @@ def one_hot_to_label( return label - def run_inference_on_slice( - model: FastSurferVINN, - test_slice: np.ndarray, - ac_center: npt.NDArray[float], - pc_center: npt.NDArray[float], + model: "torch.nn.Module", + test_slab: Image3d, + ac_center: Vector2d, + pc_center: Vector2d, voxel_size: tuple[float, float], -) -> tuple[npt.NDArray[int], np.ndarray, npt.NDArray[float]]: +) -> tuple[np.ndarray[Shape3d, np.dtype[int]], Image4d, Image4d]: """Run inference on a single slice. Parameters ---------- - model : FastSurferVINN + model : torch.nn.Module Trained model for inference. - test_slice : np.ndarray + test_slab : np.ndarray Input image slice. ac_center : npt.NDArray[float] Anterior commissure coordinates (Inferior and Anterior values). @@ -296,7 +302,7 @@ def run_inference_on_slice( ac_center = np.concatenate([np.zeros(1), ac_center]) pc_center = np.concatenate([np.zeros(1), pc_center]) - results, inputs, outputs_soft = run_inference(model, test_slice, ac_center, pc_center, voxel_size) - results = one_hot_to_label(results) + _results, inputs, outputs_soft = run_inference(model, test_slab, ac_center, pc_center, voxel_size) + results = one_hot_to_label(_results) return results, inputs, outputs_soft diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py index c7cc970b..68ecdd64 100644 --- a/CorpusCallosum/segmentation/segmentation_postprocessing.py +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -20,6 +20,7 @@ import FastSurferCNN.utils.logging as logging from CorpusCallosum.data.constants import CC_LABEL +from FastSurferCNN.utils import Mask3d, Shape3d logger = logging.get_logger(__name__) @@ -438,9 +439,9 @@ def get_cc_volume_contour( def extract_largest_connected_component( - seg_arr: np.ndarray, - max_connection_distance: float = 3.0 -) -> np.ndarray: + seg_arr: Mask3d, + max_connection_distance: float = 3.0, +) -> Mask3d: """Get largest connected component from a binary segmentation array. Parameters @@ -493,14 +494,15 @@ def extract_largest_connected_component( bincount[background] = -1 # Get largest connected component - largest_cc = labels_cc == np.argmax(bincount) + largest_cc = np.equal(labels_cc, np.argmax(bincount)) return largest_cc + def clean_cc_segmentation( - seg_arr: npt.NDArray[int], - max_connection_distance: float = 3.0 -) -> tuple[np.ndarray, np.ndarray]: + seg_arr: np.ndarray[Shape3d, np.dtype[int]], + max_connection_distance: float = 3.0, +) -> tuple[np.ndarray[Shape3d, np.dtype[int]], Mask3d]: """Clean corpus callosum segmentation by removing non-connected components. 
Parameters @@ -529,12 +531,12 @@ def clean_cc_segmentation( extract_largest = partial(extract_largest_connected_component, max_connection_distance=max_connection_distance) - # Remove non connected components from the CC alone, with minimal connections - mask = seg_arr == CC_LABEL + # Remove non-connected components from the CC alone, with minimal connections + mask = np.equal(seg_arr, CC_LABEL) cc_seg = mask.astype(int) * CC_LABEL cc_label_cleaned = np.concatenate([extract_largest(seg[None]) * CC_LABEL for seg in cc_seg], axis=0) # Add fornix to the CC labels clean_seg = np.where(mask, cc_label_cleaned, seg_arr) - return clean_seg, cc_label_cleaned > 0 + return clean_seg, np.greater(cc_label_cleaned, 0) diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py index 1189311a..c61a474a 100644 --- a/CorpusCallosum/shape/contour.py +++ b/CorpusCallosum/shape/contour.py @@ -66,8 +66,10 @@ def __init__( raise ValueError(f"Contour must be a 2D array, but is {self.contour.shape}") self.thickness_values = thickness_values if self.contour.shape[0] != len(thickness_values): - raise ValueError(f"Number of contour points ({self.contour.shape[0]}) does not match number \ - of thickness values ({len(thickness_values)})") + raise ValueError( + f"Number of contour points ({self.contour.shape[0]}) does not match number of thickness values " + f"({len(thickness_values)})" + ) # write vertex indices where thickness values are not nan self.original_thickness_vertices = np.where(~np.isnan(thickness_values))[0] self.resolution = resolution @@ -77,7 +79,6 @@ def __init__( else: self.endpoint_idxs = endpoint_idxs - def smooth_contour(self, window_size: int = 5) -> None: """Smooth a contour using a moving average filter. @@ -99,12 +100,10 @@ def smooth_contour(self, window_size: int = 5) -> None: x, y = smooth_contour(x, y, window_size) self.contour = np.array([x, y]).T - def copy(self) -> "CCContour": """Copy the contour. """ return CCContour(self.contour.copy(), self.thickness_values.copy(), self.endpoint_idxs, self.resolution) - def get_contour_edge_lengths(self) -> np.ndarray: """Get the lengths of the edges of a contour. @@ -126,8 +125,7 @@ def get_contour_edge_lengths(self) -> np.ndarray: """ edges = np.diff(self.contour, axis=0) return np.sqrt(np.sum(edges**2, axis=1)) - - + def create_levelpaths(self, num_points: int, update_data: bool = True @@ -147,7 +145,6 @@ def create_levelpaths(self, return levelpaths, thickness - def set_thickness_values(self, thickness_values: np.ndarray, use_measurement_points: bool = False) -> None: """Set the thickness values for the contour. This is useful to update the thickness values for specific plots. @@ -231,7 +228,6 @@ def fill_thickness_values(self) -> None: self.thickness_values = thickness - def smooth_thickness_values(self, iterations: int = 1) -> None: """Smooth the thickness values using a Gaussian filter. @@ -248,7 +244,6 @@ def smooth_thickness_values(self, iterations: int = 1) -> None: for i in range(len(self.thickness_values)): if self.thickness_values[i] is not None: self.thickness_values[i] = gaussian_filter1d(self.thickness_values[i], sigma=5) - def plot_contour(self, output_path: str | None = None) -> None: """Plot a single contour with thickness values. @@ -509,7 +504,6 @@ def plot_contour_colorfill( plt.show() return fig - @staticmethod def __make_parent_folder(filename: Path | str) -> None: """Create the parent folder for a file if it doesn't exist. 
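The cleaning step above (`one_hot_to_label` feeding `extract_largest_connected_component` / `clean_cc_segmentation`) boils down to an argmax over class probabilities followed by keeping the largest connected component, with small gaps bridged up to `max_connection_distance`. A minimal sketch of that idea, assuming `scipy` is available; the closing-based gap bridging and the helper names here are illustrative stand-ins, not the module's exact implementation:

```python
import numpy as np
from scipy import ndimage

def one_hot_to_label_sketch(one_hot: np.ndarray, label_ids=(0, 192, 250)) -> np.ndarray:
    # argmax over the trailing class axis, then map class index -> label ID
    return np.asarray(label_ids)[np.argmax(one_hot, axis=-1)]

def largest_component_sketch(mask: np.ndarray, max_connection_distance: float = 3.0) -> np.ndarray:
    # bridge gaps up to ~max_connection_distance voxels so nearly-touching
    # fragments count as one component, then keep the biggest one
    radius = max(1, int(np.ceil(max_connection_distance)))
    closed = ndimage.binary_closing(mask, structure=np.ones((radius,) * mask.ndim, bool))
    labels_cc, num = ndimage.label(closed)
    if num == 0:
        return np.zeros_like(mask)
    sizes = np.bincount(labels_cc.ravel())
    sizes[0] = 0  # never pick the background component
    return (labels_cc == np.argmax(sizes)) & mask  # restrict to original voxels
```

As in `clean_cc_segmentation`, the cleaned CC mask would then be written back with `np.where(mask, cleaned, seg_arr)` so labels outside the CC mask (e.g. the fornix) stay untouched.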
@@ -526,7 +520,6 @@ def __make_parent_folder(filename: Path | str) -> None: """ Path(filename).parent.mkdir(parents=False, exist_ok=True) - def save_contour(self, output_path: Path | str) -> None: """Save the contours to a CSV file. @@ -554,7 +547,6 @@ def save_contour(self, output_path: Path | str) -> None: for point in self.contour: f.write(f"{point[0]},{point[1]}\n") - def load_contour(self, input_path: str) -> None: """Load contour from a CSV file. @@ -599,7 +591,6 @@ def load_contour(self, input_path: str) -> None: current_points.append([float(x), float(y)]) self.contour = np.array(current_points) - def save_thickness_values(self, output_path: Path | str) -> None: """Save thickness values to a CSV file. @@ -622,7 +613,6 @@ def save_thickness_values(self, output_path: Path | str) -> None: for value in self.thickness_values: f.write(f"{value}\n") - def load_thickness_values( self, input_path: str, diff --git a/CorpusCallosum/shape/mesh.py b/CorpusCallosum/shape/mesh.py index 09084bc5..62b36c1d 100644 --- a/CorpusCallosum/shape/mesh.py +++ b/CorpusCallosum/shape/mesh.py @@ -177,10 +177,12 @@ def make_triangles_between_contours(contour1: np.ndarray, contour2: np.ndarray) -def create_CC_mesh_from_contours(contours: list[CCContour], - lr_center: float = 0, - closed: bool = False, - smooth: int = 0) -> None: +def create_CC_mesh_from_contours( + contours: list[CCContour], + lr_center: float = 0, + closed: bool = False, + smooth: int = 0, +) -> "CCMesh": """Create a surface mesh by triangulating between consecutive contours. Parameters @@ -194,6 +196,11 @@ def create_CC_mesh_from_contours(contours: list[CCContour], smooth : int, optional Number of smoothing iterations to apply, by default 0. + Returns + ------- + CCMesh + The joined CCMesh object. + Raises ------ Warning diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index c899b10e..0f7a7ed5 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
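`create_CC_mesh_from_contours` (now returning the joined `CCMesh`) triangulates between consecutive per-slice contours. A minimal sketch of the underlying triangle-strip layout between two closed contours with matching vertex counts; the function and variable names are illustrative, and the real `make_triangles_between_contours` may order vertices differently:

```python
import numpy as np

def triangles_between_rings(n: int, offset0: int, offset1: int) -> np.ndarray:
    """Triangle indices joining two closed n-vertex contours into a strip."""
    i = np.arange(n)
    j = (i + 1) % n                      # wrap: contours are closed loops
    a, b = offset0 + i, offset0 + j      # edge (a, b) on the first ring
    c, d = offset1 + i, offset1 + j      # matching edge (c, d) on the second
    # split every quad (a, b, d, c) into two triangles
    return np.concatenate([np.stack([a, b, c], axis=1), np.stack([b, d, c], axis=1)])

# two square contours stacked 1 mm apart along the left-right axis
ring = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
verts = np.vstack([np.column_stack([np.full(4, x), ring]) for x in (0.0, 1.0)])
faces = triangles_between_rings(4, offset0=0, offset1=4)  # (8, 3) index array
```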
import concurrent.futures +from copy import copy from functools import partial +from pathlib import Path from typing import Literal, TypedDict, get_args import numpy as np @@ -24,6 +26,7 @@ from CorpusCallosum.shape.mesh import create_CC_mesh_from_contours from CorpusCallosum.shape.metrics import calculate_cc_index from CorpusCallosum.shape.subsegment_contour import ( + ContourList, get_primary_eigenvector, hampel_subdivide_contour, subdivide_contour, @@ -31,8 +34,9 @@ transform_to_acpc_standard, ) from CorpusCallosum.shape.thickness import cc_thickness, convert_to_ras +from CorpusCallosum.utils.types import ContourThickness, Points2dType from CorpusCallosum.utils.visualization import plot_contours -from FastSurferCNN.utils import AffineMatrix4x4, Vector2d +from FastSurferCNN.utils import AffineMatrix4x4, Image3d, ScalarType, Shape2d, Shape3d, Vector2d from FastSurferCNN.utils.common import SubjectDirectory, suppress_stdout, update_docstring from FastSurferCNN.utils.parallel import process_executor, thread_executor @@ -91,7 +95,7 @@ class CCMeasuresDict(TypedDict): total_perimeter: float total_area: float total_perimeter: float - split_contours: list[np.ndarray] + split_contours: ContourList midline_equidistant: np.ndarray levelpaths: list[np.ndarray] slice_index: int @@ -121,19 +125,19 @@ def create_sag_slice_vox2vox(slice_idx: int, fsaverage_middle: float) -> AffineM @update_docstring(SubdivisionMethod=str(get_args(SubdivisionMethod))[1:-1]) def recon_cc_surf_measures_multi( - segmentation: np.ndarray, + segmentation: np.ndarray[Shape3d, np.dtype[int]], slice_selection: SliceSelection, - fsavg_vox2ras: np.ndarray, - midslices: np.ndarray, - ac_coords: np.ndarray, - pc_coords: np.ndarray, + fsavg_vox2ras: AffineMatrix4x4, + midslices: Image3d, + ac_coords: Vector2d, + pc_coords: Vector2d, num_thickness_points: int, subdivisions: list[float], subdivision_method: SubdivisionMethod, contour_smoothing: int, subject_dir: SubjectDirectory, vox_size: tuple[float, float, float], - vox2ras_tkr: np.ndarray | None = None, + vox2ras_tkr: AffineMatrix4x4 | None = None, ) -> tuple[list[CCMeasuresDict], list[concurrent.futures.Future]]: """Surface reconstruction and metrics computation of corpus callosum slices based on selection mode. 
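The batched composition `fsavg_vox2ras @ np.stack([...])` above exploits numpy's broadcasting matmul: one (4, 4) vox2ras applied to an (N, 4, 4) stack of per-slice vox2vox matrices yields an (N, 4, 4) stack of per-slice vox2ras affines. A small sketch, under the assumption that the slice matrix is a pure sagittal translation (the real `create_sag_slice_vox2vox` placement may differ):

```python
import numpy as np

def sag_slice_vox2vox_sketch(slice_idx: int, fsaverage_middle: float = 128.0) -> np.ndarray:
    m = np.eye(4)
    m[0, 3] = fsaverage_middle + slice_idx  # illustrative placement only
    return m

fsavg_vox2ras = np.eye(4)  # stand-in for the real fsaverage affine
slices_to_recon = [126, 127, 128]
slice_mats = np.stack([sag_slice_vox2vox_sketch(i) for i in slices_to_recon])
per_slice_vox2ras = fsavg_vox2ras @ slice_mats  # (4,4) @ (3,4,4) -> (3,4,4)
assert per_slice_vox2ras.shape == (3, 4, 4)
```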
@@ -199,7 +203,6 @@ def recon_cc_surf_measures_multi( num_slices = 1 # Process only the middle slice slices_to_recon = [segmentation.shape[0] // 2] - start_slice = segmentation.shape[0] // 2 elif slice_selection == "all": num_slices = segmentation.shape[0] start_slice = 0 @@ -208,7 +211,6 @@ def recon_cc_surf_measures_multi( else: # specific slice number num_slices = 1 slices_to_recon = [int(slice_selection)] - start_slice = int(slice_selection) _gen_fsavg2slice_vox2vox = partial(create_sag_slice_vox2vox, fsaverage_middle=FSAVERAGE_MIDDLE) per_slice_vox2ras = fsavg_vox2ras @ np.stack(list(map(_gen_fsavg2slice_vox2vox, slices_to_recon)), axis=0) @@ -219,29 +221,28 @@ def recon_cc_surf_measures_multi( for i, (slice_idx, _results) in enumerate(zip(slices_to_recon, per_slice_recon, strict=True)): progress = f" ({i+1} of {num_slices})" if num_slices > 1 else "" logger.info(f"Calculating CC measurements for slice {slice_idx+1}{progress}") - cc_measures, contour_in_as_space_and_thickness, endpoint_idxs = _results - contour_in_as_space, thickness_values = np.split(contour_in_as_space_and_thickness, (2,), axis=1) - cc_contours.append( - CCContour(contour_in_as_space, thickness_values[:, 0], endpoint_idxs, resolution=vox_size[0]) - ) + # unpack values from _results + cc_measures: CCMeasuresDict = _results[0] + contour_in_as_space_and_thickness: ContourThickness = _results[1] + endpoint_idxs: tuple[int, int] = _results[2] + contour_in_as_space: Points2dType = contour_in_as_space_and_thickness[:, :2] + thickness_values: np.ndarray[tuple[int], np.dtype[float]] = contour_in_as_space_and_thickness[:, 2] + + cc_contours.append(CCContour(contour_in_as_space, thickness_values, endpoint_idxs, resolution=vox_size[0])) if cc_measures is None: # this should not happen, but just in case logger.warning(f"Slice index {slice_idx+1}{progress} returned result `None`") slice_cc_measures.append(cc_measures) - - if subject_dir.has_attribute("cc_qc_image"): - qc_img = subject_dir.filename_by_attribute("cc_qc_image") - if logger.getEffectiveLevel() <= logging.DEBUG: - qc_slice_img = (qc_img.parent / f"{qc_img.stem}_slice_{slice_idx}{qc_img.suffix}").with_suffix(".png") - if slice_idx == num_slices // 2: - qc_img = qc_img, qc_slice_img - else: - qc_img = qc_slice_img - - if logger.getEffectiveLevel() <= logging.DEBUG or slice_idx == num_slices // 2: - logger.info(f"Saving segmentation qc image to {qc_img}") - + is_debug = logger.getEffectiveLevel() <= logging.DEBUG + is_midslice = slice_idx == num_slices // 2 + if subject_dir.has_attribute("cc_qc_image") and (is_debug or is_midslice): + qc_imgs: list[Path] = (subject_dir.filename_by_attribute("cc_qc_image"),) + if is_debug: + qc_slice_img = qc_imgs[0].with_suffix(f".slice_{slice_idx}.png") + qc_imgs = (qc_imgs if is_midslice else []) + [qc_slice_img] + + logger.info(f"Saving segmentation qc image to {', '.join(map(str, qc_imgs))}") current_slice_in_volume = midslices.shape[0] // 2 - num_slices // 2 + slice_idx # Create visualization for this slice io_futures.append( @@ -251,7 +252,7 @@ def recon_cc_surf_measures_multi( split_contours=cc_measures["split_contours"], midline_equidistant=cc_measures["midline_equidistant"], levelpaths=cc_measures["levelpaths"], - output_path=qc_img, + output_path=qc_imgs, ac_coords=ac_coords, pc_coords=pc_coords, vox_size=vox_size, @@ -266,16 +267,17 @@ def recon_cc_surf_measures_multi( template_dir.mkdir(parents=True, exist_ok=True) logger.info("Saving template files (contours.txt, thickness_values.txt, " f"thickness_measurement_points.txt) to 
{template_dir}") + run = thread_executor().submit for j in range(len(cc_contours)): - # NOTE: this does not seem to be thread-safe, do not parallelize! - cc_contours[j].save_contour(template_dir / f"contour_{j}.txt") - cc_contours[j].save_thickness_values(template_dir / f"thickness_values_{j}.txt") + # FIXME: check, if this is fixed (thickness values not nan == 200) + # this does not seem to be thread-safe, do not parallelize! + io_futures.append(run(cc_contours[j].save_contour, template_dir / f"contour_{j}.txt")) + io_futures.append(run(cc_contours[j].save_thickness_values, template_dir / f"thickness_values_{j}.txt")) mesh_outputs = ("html", "mesh", "thickness_overlay", "surf", "thickness_image") if len(cc_contours) > 1 and any(subject_dir.has_attribute(f"cc_{n}") for n in mesh_outputs): - for j in range(len(cc_contours)): - cc_contours[j].fill_thickness_values() - cc_mesh = create_CC_mesh_from_contours(cc_contours, smooth=1) + _cc_contours = thread_executor().map(_resample_thickness, cc_contours) + cc_mesh = create_CC_mesh_from_contours(list(cc_contours), smooth=1) if subject_dir.has_attribute("cc_html"): logger.info(f"Saving CC 3D visualization to {subject_dir.filename_by_attribute('cc_html')}") io_futures.append(thread_executor().submit( @@ -304,7 +306,6 @@ def recon_cc_surf_measures_multi( with suppress_stdout(): cc_mesh.snap_cc_picture(thickness_image_path) - if not slice_cc_measures: logger.error("Error: No valid slices were found for postprocessing") raise ValueError("No valid slices were found for postprocessing") @@ -312,8 +313,15 @@ def recon_cc_surf_measures_multi( return slice_cc_measures, io_futures +def _resample_thickness(contour: CCContour) -> CCContour: + """Resamples the thickness values of contour.""" + _c = copy(contour) + _c.fill_thickness_values() + return _c + + def recon_cc_surf_measure( - segmentation: np.ndarray[tuple[int, int], np.integer], + segmentation: np.ndarray[Shape2d, np.dtype[int]], slice_idx: int, affine: AffineMatrix4x4, ac_coords: Vector2d, @@ -323,7 +331,7 @@ def recon_cc_surf_measure( subdivision_method: SubdivisionMethod, contour_smoothing: int, vox_size: tuple[float, float, float], -) -> tuple[CCMeasuresDict, np.ndarray, tuple[int, int]]: +) -> tuple[CCMeasuresDict, ContourThickness, tuple[int, int]]: """Reconstruct surfaces and compute measures for a single slice for the corpus callosum. Parameters @@ -354,7 +362,7 @@ def recon_cc_surf_measure( measures : CCMeasuresDict Dictionary containing measurements if successful. contour_with_thickness : np.ndarray - Contour points with thickness information. + Contour points with thickness information, shape (3, N) for [x, y, thickness]. endpoint_indices : pair of ints Indices of the anterior and posterior endpoints on the contour. 
@@ -386,7 +394,7 @@ def recon_cc_surf_measure( contour_ras = convert_to_ras(contour, affine) endpoint_idxs: tuple[int, int] - contour_with_thickness: np.ndarray[tuple[int, Literal[3]], np.floating] + contour_with_thickness: ContourThickness midline_len, thickness, curvature, midline_equi, levelpaths, contour_with_thickness, endpoint_idxs = cc_thickness( contour_ras[1:].T, endpoint_idxs, @@ -404,6 +412,7 @@ def recon_cc_surf_measure( cc_index = calculate_cc_index(contour_in_acpc_space) # Apply different subdivision methods based on user choice + split_contours: ContourList if subdivision_method == "shape": _subdivisions = np.asarray(subdivisions) areas, split_contours = subsegment_midline_orthogonal(midline_equi, _subdivisions, contour_ras[1:], plot=False) @@ -425,6 +434,8 @@ def recon_cc_surf_measure( ac_pt_eigen = ac_pt_eigen[:, 0] areas, split_contours = subdivide_contour(contour_eigen, subdivisions, oriented=True, hline_anchor=ac_pt_eigen) split_contours = [rotate_back_eigen(split_contour) for split_contour in split_contours] + else: + raise ValueError(f"Invalid subdivision method {subdivision_method}") total_area = np.sum(areas) total_perimeter = np.sum(np.sqrt(np.sum((np.diff(contour_ras[:, 1:], axis=0))**2, axis=1))) @@ -451,8 +462,12 @@ def recon_cc_surf_measure( return measures, contour_with_thickness, endpoint_idxs -def vectorized_line_test(coords_x: np.ndarray, coords_y: np.ndarray, - line_start: np.ndarray, line_end: np.ndarray) -> np.ndarray: +def vectorized_line_test( + coords_x: np.ndarray[tuple[int], np.dtype[ScalarType]], + coords_y: np.ndarray[tuple[int], np.dtype[ScalarType]], + line_start: Vector2d, + line_end: Vector2d, +) -> np.ndarray[tuple[int], np.dtype[bool]]: """Vectorized version of point_relative_to_line for arrays of points. Parameters @@ -471,6 +486,7 @@ def vectorized_line_test(coords_x: np.ndarray, coords_y: np.ndarray, np.ndarray Boolean array where True means point is to the left of the line. """ + # FIXME: rename this function to something more indicative # Vector from line_start to line_end line_vec = np.array(line_end) - np.array(line_start) @@ -484,18 +500,18 @@ def vectorized_line_test(coords_x: np.ndarray, coords_y: np.ndarray, return cross_products > 0 -def get_unique_contour_points(split_contours: list[tuple[np.ndarray, np.ndarray]]) -> list[np.ndarray]: +def get_unique_contour_points(split_contours: ContourList) -> list[Points2dType]: """Get unique contour points from the split contours. Parameters ---------- - split_contours : list[tuple[np.ndarray, np.ndarray]] - List of split contours (subsegmentations), each containing x and y coordinates. + split_contours : ContourList + List of split contours (subsegmentations), each containing x and y coordinates, each of shape (2, N). Returns ------- list[np.ndarray] - List of unique contour points for each subsegment. + List of unique contour points for each subsegment, each of shape (N, 2). Notes ----- @@ -509,11 +525,11 @@ def get_unique_contour_points(split_contours: list[tuple[np.ndarray, np.ndarray] 3. Collects points unique to each subsegment. 
""" # For each contour point, check if it appears in other contours - unique_contour_points = [] + unique_contour_points: list[Points2dType] = [] for i, contour in enumerate(split_contours): # Get points for this contour - contour_points = np.vstack((contour[0], -contour[1])).T # Shape: (N,2) + contour_points: Points2dType = np.vstack((contour[0], -contour[1])).T # Shape: (N,2) # Check each point against all other contours unique_points = [] @@ -541,21 +557,21 @@ def get_unique_contour_points(split_contours: list[tuple[np.ndarray, np.ndarray] def make_subdivision_mask( - slice_shape: tuple[int, int], - split_contours: list[tuple[np.ndarray, np.ndarray]], + slice_shape: Shape2d, + split_contours: ContourList, vox_size: tuple[float, float, float], -) -> np.ndarray: +) -> np.ndarray[Shape2d, np.dtype[int]]: """Create a mask for subdividing the corpus callosum based on split contours. Parameters ---------- - slice_shape : tuple[int, int] + slice_shape : pair of ints Shape of the slice (rows, cols). - split_contours : list[tuple[np.ndarray, np.ndarray]] + split_contours : ContourList List of contours defining the subdivisions. Each contour is a tuple of x and y coordinates. - vox_size : triplet of floats - The voxel sizes of the image grid in LIA orientation. + vox_size : pair of floats + The voxel sizes of the image grid in AS orientation. Returns ------- @@ -575,7 +591,7 @@ def make_subdivision_mask( """ # unique contour points are the points where sub-division lines were inserted - unique_contour_points = get_unique_contour_points(split_contours) + unique_contour_points: list[Points2dType] = get_unique_contour_points(split_contours) # shape (N, 2) subdivision_segments = unique_contour_points[1:] for s in subdivision_segments: @@ -586,22 +602,24 @@ def make_subdivision_mask( rows, cols = slice_shape y_coords, x_coords = np.mgrid[0:rows, 0:cols] - subsegment_labels_anterior_posterior = SUBSEGMENT_LABELS.copy() - subsegment_labels_anterior_posterior.reverse() + cc_subsegment_lut_anterior_to_posterior = SUBSEGMENT_LABELS.copy() + cc_subsegment_lut_anterior_to_posterior.reverse() # Initialize with first segment label - subdivision_mask = np.full(slice_shape, subsegment_labels_anterior_posterior[0], dtype=np.int32) + subdivision_mask = np.full(slice_shape, cc_subsegment_lut_anterior_to_posterior[0], dtype=np.int32) # Process each subdivision line for segment_idx, segment_points in enumerate(subdivision_segments): - line_start = segment_points[0] / vox_size[0] - line_end = segment_points[-1] / vox_size[0] + # FIXME: names for line_start and line_end? + line_start: Vector2d = segment_points[0] / vox_size + line_end: Vector2d = segment_points[-1] / vox_size # Vectorized test: find all points to the right of this line + # FIXME: line defined by what? Is this inside the polygon or the line from line_start to line_end? points_right_of_line = vectorized_line_test(x_coords, y_coords, line_start, line_end) # All points to the right of this line belong to the next segment or beyond - subdivision_mask[points_right_of_line] = subsegment_labels_anterior_posterior[segment_idx + 1] + subdivision_mask[points_right_of_line] = cc_subsegment_lut_anterior_to_posterior[segment_idx + 1] return subdivision_mask diff --git a/CorpusCallosum/shape/subsegment_contour.py b/CorpusCallosum/shape/subsegment_contour.py index d1a7d112..cd62e139 100644 --- a/CorpusCallosum/shape/subsegment_contour.py +++ b/CorpusCallosum/shape/subsegment_contour.py @@ -13,23 +13,25 @@ # limitations under the License. 
from collections.abc import Callable -from typing import Literal, TypeVar +from typing import TYPE_CHECKING, Literal import matplotlib.pyplot as plt import numpy as np -from numpy import typing as npt from scipy.spatial import ConvexHull -_TS = TypeVar("_TS", bound=np.number) +from CorpusCallosum.utils.types import ContourList, Points2dType, Polygon2dType, Polygon3dType +from FastSurferCNN.utils import Mask2d, Mask3d, ScalarType, Vector2d, nibabelImage +if TYPE_CHECKING: + import pandas as pd -def minimum_bounding_rectangle(points): +def minimum_bounding_rectangle(points: Points2dType) -> np.ndarray[tuple[Literal[4], Literal[2]], np.dtype[ScalarType]]: """Find the smallest bounding rectangle for a set of points. Parameters ---------- - points : np.ndarray - Array of shape (N, 2) containing point coordinates. + points : array + An array of shape (N, 2) containing point coordinates. Returns ------- @@ -37,13 +39,12 @@ def minimum_bounding_rectangle(points): Array of shape (4, 2) containing coordinates of the bounding box corners. """ pi2 = np.pi / 2.0 - points = points.T + points = np.asarray(points).T # get the convex hull for the points hull_points = points[ConvexHull(points).vertices] # calculate edge angles - edges = np.zeros((len(hull_points) - 1, 2)) edges = hull_points[1:] - hull_points[:-1] angles = np.arctan2(edges[:, 1], edges[:, 0]) @@ -84,17 +85,17 @@ def minimum_bounding_rectangle(points): return rval -def calc_subsegment_areas(split_contours: list[npt.NDArray[_TS]]) -> npt.NDArray[_TS]: +def calc_subsegment_areas(split_contours: ContourList) -> np.ndarray[tuple[int], np.dtype[ScalarType]]: """Calculate area of each subsegment using the shoelace formula. Parameters ---------- - split_contours : list[np.ndarray] + split_contours : list of np.ndarray List of contour arrays, each of shape (2, N). Returns ------- - subsegment_areas : np.ndarray + subsegment_areas : array of floats Array containing the area of each subsegment. """ # calculate area of each split contour using the shoelace formula @@ -105,22 +106,22 @@ def calc_subsegment_areas(split_contours: list[npt.NDArray[_TS]]) -> npt.NDArray def subsegment_midline_orthogonal( - midline: np.ndarray[tuple[int, Literal[2]], np.dtype[float]], + midline: Points2dType, area_weights: np.ndarray[tuple[int], np.dtype[float]], - contour: np.ndarray[tuple[Literal[2], int], np.dtype[_TS]], + contour: Polygon2dType, plot: bool = True, ax=None, extremes=None, -): +) -> tuple[np.ndarray[tuple[int], np.dtype[ScalarType]], ContourList]: """Subsegment contour orthogonally to the midline based on area weights. Parameters ---------- - midline : np.ndarray + midline : array of floats Array of shape (N, 2) containing midline points. - area_weights : np.ndarray + area_weights : array of floats Array of weights for area-based subdivision. - contour : np.ndarray + contour : array of floats Array of shape (2, M) containing contour points in as space. plot : bool, optional Whether to plot the results, by default True. @@ -131,12 +132,14 @@ def subsegment_midline_orthogonal( Returns ------- - subsegment_areas : list of float + subsegment_areas : array of floats List of subsegment areas. split_contours : list of np.ndarray List of contour arrays for each subsegment. - """ + # FIXME: Here and in other places, the order of dimensions is pretty inconsistent, for example: midline is (N, 2), + # but contours are (2, N)... + # FIXME: why does this code return subsegments that include all previous segments? 
# get points after midline length of splits @@ -155,7 +158,7 @@ def subsegment_midline_orthogonal( edge_ortho_vectors = np.column_stack((-edge_directions[:, 1], edge_directions[:, 0])) edge_ortho_vectors = edge_ortho_vectors / np.linalg.norm(edge_ortho_vectors, axis=1)[:, None] - split_contours = [contour] + split_contours: ContourList = [contour] # FIXME: double loop should be vectorized, see commented code below for an initial attempt (not tested) # also, finding intersections can be done more efficiently, instead of solving linear system for each segment @@ -364,7 +367,9 @@ def subsegment_midline_orthogonal( return calc_subsegment_areas(split_contours), split_contours -def hampel_subdivide_contour(contour, num_rays, plot=False, ax=None): +def hampel_subdivide_contour(contour: Polygon2dType, num_rays: int, plot: bool = False, ax=None) \ + -> tuple[np.ndarray[tuple[int], np.dtype[float]], ContourList]: + # FIXME: needs docstring # Find the extreme points in the x-direction min_x_index = np.argmin(contour[0]) contour = np.roll(contour, -min_x_index, axis=1) @@ -417,7 +422,7 @@ def hampel_subdivide_contour(contour, num_rays, plot=False, ax=None): ray_vectors[0] = -ray_vectors[0] # Subdivision logic - split_contours = [] + split_contours: ContourList = [] for ray_vector in ray_vectors.T: intersections = [] for i in range(contour.shape[1] - 1): @@ -501,14 +506,14 @@ def hampel_subdivide_contour(contour, num_rays, plot=False, ax=None): def subdivide_contour( - contour: np.ndarray, + contour: Polygon2dType, area_weights: list[float], plot: bool = False, ax: plt.Axes | None = None, plot_transform: Callable | None = None, oriented: bool = False, hline_anchor: np.ndarray | None = None -): +) -> tuple[np.ndarray[tuple[int], np.dtype[float]], ContourList]: """Subdivide contour based on area weights using vertical lines. Divides the contour into segments by drawing vertical lines at positions @@ -777,7 +782,11 @@ def subdivide_contour( return calc_subsegment_areas(split_contours), split_contours -def transform_to_acpc_standard(contour_ras, ac_pt_ras, pc_pt_ras): +def transform_to_acpc_standard( + contour_ras: Polygon2dType | Polygon3dType, + ac_pt_ras: Vector2d, + pc_pt_ras: Vector2d, +) -> tuple[Polygon2dType, Vector2d, Vector2d, Callable[[Polygon2dType], Polygon2dType]]: """Transform contour coordinates to AC-PC standard space. Transforms the contour coordinates by: @@ -787,12 +796,12 @@ def transform_to_acpc_standard(contour_ras, ac_pt_ras, pc_pt_ras): Parameters ---------- - contour_ras : np.ndarray + contour_ras : array of floats Array of shape (2, N) or (3, N) containing contour points in RAS space. ac_pt_ras : np.ndarray - Anterior commissure point coordinates in RAS space. + Anterior commissure point coordinates in AS space. pc_pt_ras : np.ndarray - Posterior commissure point coordinates in RAS space. + Posterior commissure point coordinates in AS space. Returns ------- @@ -804,15 +813,14 @@ def transform_to_acpc_standard(contour_ras, ac_pt_ras, pc_pt_ras): PC point in AC-PC space. rotate_back : callable Function to transform points back to RAS space. 
- """ # translate AC to the origin and PC to (0, ac_pc_dist) translation_matrix = np.array([[1, 0, -ac_pt_ras[0]], [0, 1, -ac_pt_ras[1]], [0, 0, 1]]) - ac_pc_vec = pc_pt_ras - ac_pt_ras + ac_pc_vec: Vector2d = pc_pt_ras - ac_pt_ras ac_pc_dist = np.linalg.norm(ac_pc_vec) - posterior_vector = np.array([-ac_pc_dist, 0]) + posterior_vector: Vector2d = np.array([-ac_pc_dist, 0], dtype=float) # get angle between ac_pc_vec and posterior_vector dot_product = np.dot(ac_pc_vec, posterior_vector) @@ -833,16 +841,17 @@ def transform_to_acpc_standard(contour_ras, ac_pt_ras, pc_pt_ras): else: contour_ras_homogeneous = contour_ras - contour_acpc = (rotation_matrix @ translation_matrix) @ contour_ras_homogeneous + contour_acpc: Polygon2dType = (rotation_matrix @ translation_matrix) @ contour_ras_homogeneous contour_acpc = contour_acpc[:2, :] - def rotate_back(x): + def rotate_back(x: Polygon2dType) -> Polygon2dType: return (np.linalg.inv(rotation_matrix @ translation_matrix) @ np.vstack([x, np.ones(x.shape[1])]))[:2, :] - return contour_acpc, np.array([0, 0]), np.array([-ac_pc_dist, 0]), rotate_back + return contour_acpc, np.array([0, 0], dtype=float), np.array([-ac_pc_dist, 0], dtype=float), rotate_back -def preprocess_cc(cc_label_nib, paths_csv, subj_id): +def preprocess_cc(cc_label_nib: nibabelImage, paths_csv: "pd.DataFrame", subj_id: str) \ + -> tuple[Mask2d, Vector2d, Vector2d]: """Preprocess corpus callosum mask and extract AC/PC coordinates. Parameters @@ -864,8 +873,8 @@ def preprocess_cc(cc_label_nib, paths_csv, subj_id): 2D coordinates of posterior commissure. """ - cc_mask = np.asarray(cc_label_nib.dataobj) == 192 - cc_mask = cc_mask[cc_mask.shape[0] // 2] + _cc_mask: Mask3d = np.asarray(cc_label_nib.dataobj) == 192 + cc_mask: Mask2d = _cc_mask[_cc_mask.shape[0] // 2] posterior_commisure_center = paths_csv.loc[subj_id, "PC_center_r":"PC_center_s"].to_numpy().astype(float) anterior_commisure_center = paths_csv.loc[subj_id, "AC_center_r":"AC_center_s"].to_numpy().astype(float) @@ -876,13 +885,13 @@ def preprocess_cc(cc_label_nib, paths_csv, subj_id): # orientation I, A # rotate image so anterior and posterior commisure are horizontal - AC_2d = anterior_commisure_center[1:] - PC_2d = posterior_commisure_center[1:] + ac_2d = anterior_commisure_center[1:] + pc_2d = posterior_commisure_center[1:] - return cc_mask, AC_2d, PC_2d + return cc_mask, ac_2d, pc_2d -def get_primary_eigenvector(contour_ras): +def get_primary_eigenvector(contour_ras: Polygon2dType) -> tuple[Vector2d, Vector2d]: """Calculate primary eigenvector of contour points using PCA. Computes the principal direction of the contour by: diff --git a/CorpusCallosum/shape/thickness.py b/CorpusCallosum/shape/thickness.py index 15b99503..10fd9c8f 100644 --- a/CorpusCallosum/shape/thickness.py +++ b/CorpusCallosum/shape/thickness.py @@ -19,10 +19,11 @@ from lapy.diffgeo import compute_rotated_f from meshpy import triangle +from CorpusCallosum.utils.types import ContourThickness, Points2dType from FastSurferCNN.utils.common import suppress_stdout -def compute_curvature(path: np.ndarray) -> np.ndarray: +def compute_curvature(path: Points2dType) -> np.ndarray[tuple[int], np.dtype[float]]: """Compute curvature by computing edge angles. Parameters @@ -306,9 +307,7 @@ def make_mesh_from_contour( info.set_points(contour_2d) info.set_facets(facets) # NOTE: crashes if contour has duplicate points !! 
- mesh = triangle.build( - info, max_volume=max_volume, min_angle=min_angle, verbose=verbose, - ) + mesh = triangle.build(info, max_volume=max_volume, min_angle=min_angle, verbose=verbose) mesh_points = np.array(mesh.points) mesh_trias = np.array(mesh.elements) @@ -317,10 +316,10 @@ def make_mesh_from_contour( def cc_thickness( - contour_2d: np.ndarray, + contour_2d: Points2dType, endpoint_idx: tuple[int, int], - n_points: int = 100 -) -> tuple[float, float, float, np.ndarray, list[np.ndarray], np.ndarray, tuple[int, int]]: + n_points: int = 100, +) -> tuple[float, float, float, Points2dType , list[Points2dType], ContourThickness, tuple[int, int]]: """Calculate corpus callosum thickness using Laplace equation. Parameters @@ -329,8 +328,8 @@ def cc_thickness( Array of shape (N, 2) containing contour points. endpoint_idx : pair of ints Indices of anterior and posterior endpoints in contour. - n_points : int, optional - Number of points for thickness measurement, by default 100. + n_points : int, default=100 + Number of points for thickness measurement. Returns ------- @@ -341,9 +340,9 @@ def cc_thickness( curvature : float Mean absolute curvature in degrees. midline_equidistant : np.ndarray - Equidistant points along the midline in same space as contour2d. + Equidistant points along the midline in same space as contour2d of shape (N, 2). levelpaths : list[np.ndarray] - Level paths for thickness measurement in same space as contour2d. + Level paths for thickness measurement in same space as contour2d, each of shape (N, 2). contour_with_thickness : np.ndarray Contour coordinates with thickness information in same space as contour2d of shape (N+2, 3). endpoint_indices : pair of ints @@ -402,7 +401,7 @@ def cc_thickness( ) # get levels to evaluate - levelpaths_contour_space: list[np.ndarray] = [] + levelpaths_contour_space: list[Points2dType] = [] levelpath_lengths = [] levelpath_tria_idx = [] diff --git a/CorpusCallosum/utils/mapping_helpers.py b/CorpusCallosum/utils/mapping_helpers.py index c36ccd26..8e103dee 100644 --- a/CorpusCallosum/utils/mapping_helpers.py +++ b/CorpusCallosum/utils/mapping_helpers.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Literal +from typing import overload import nibabel as nib import numpy as np @@ -8,24 +8,33 @@ from scipy.ndimage import affine_transform from CorpusCallosum.data.constants import CC_LABEL, FORNIX_LABEL -from FastSurferCNN.utils import AffineMatrix4x4, logging +from CorpusCallosum.utils.types import Polygon3dType +from FastSurferCNN.utils import ( + AffineMatrix4x4, + Image2d, + Image3d, + RotationMatrix3x3, + Shape3d, + Vector2d, + Vector3d, + logging, + nibabelImage, +) from FastSurferCNN.utils.parallel import thread_executor -Vector3D = np.ndarray[tuple[Literal[3]], np.dtype[float]] - logger = logging.get_logger(__name__) def make_midplane_affine( - orig_affine: npt.NDArray[float], + orig_affine: AffineMatrix4x4, slices_to_analyze: int = 1, offset: int = 4, - ) -> npt.NDArray[float]: + ) -> AffineMatrix4x4: """Create affine transformation matrix for midplane slices. Parameters ---------- - orig_affine : np.ndarray + orig_affine : AffineMatrix4x4 Original image affine matrix (4x4). slices_to_analyze : int, default=1 Number of slices to analyze around midplane. @@ -34,7 +43,7 @@ def make_midplane_affine( Returns ------- - np.ndarray + AffineMatrix4x4 4x4 affine matrix for midplane slices. 
""" # Create translation matrix to center on midplane @@ -47,7 +56,7 @@ def make_midplane_affine( return seg_affine -def correct_nodding(ac_pt: npt.NDArray[float], pc_pt: npt.NDArray[float]) -> npt.NDArray[float]: +def correct_nodding(ac_pt: Vector2d, pc_pt: Vector2d) -> RotationMatrix3x3: """Calculate rotation matrix to correct head nodding. Calculates rotation matrix to align AC-PC line with posterior direction, @@ -55,14 +64,14 @@ def correct_nodding(ac_pt: npt.NDArray[float], pc_pt: npt.NDArray[float]) -> npt Parameters ---------- - ac_pt : np.ndarray - Coordinates of the anterior commissure point. - pc_pt : np.ndarray - Coordinates of the posterior commissure point. + ac_pt : Vector2d + 2D coordinates of the anterior commissure point. + pc_pt : Vector2d + 2D coordinates of the posterior commissure point. Returns ------- - np.ndarray + RotationMatrix 3x3 rotation matrix to align AC-PC line with posterior direction. """ ac_pc_vec = pc_pt - ac_pt @@ -81,7 +90,7 @@ def correct_nodding(ac_pt: npt.NDArray[float], pc_pt: npt.NDArray[float]) -> npt theta = -theta # create rotation matrix for theta - rotation_matrix = np.array( + rotation_matrix: RotationMatrix3x3 = np.array( [ [np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], @@ -92,7 +101,13 @@ def correct_nodding(ac_pt: npt.NDArray[float], pc_pt: npt.NDArray[float]) -> npt return rotation_matrix -def apply_transform_to_pt(pts: npt.NDArray[float], T: npt.NDArray[float], inv: bool = False) -> npt.NDArray[float]: +@overload +def apply_transform_to_pt(pts: Vector3d, T: AffineMatrix4x4, inv: bool = False) -> Vector3d: ... + +@overload +def apply_transform_to_pt(pts: Polygon3dType, T: AffineMatrix4x4, inv: bool = False) -> Polygon3dType: ... + +def apply_transform_to_pt(pts: Vector3d | Polygon3dType, T: AffineMatrix4x4, inv: bool = False): """Apply homogeneous transformation matrix to points. Parameters @@ -120,10 +135,10 @@ def apply_transform_to_pt(pts: npt.NDArray[float], T: npt.NDArray[float], inv: b def calc_mapping_to_standard_space( orig: "nib.Nifti1Image", - ac_coords_3d: Vector3D, - pc_coords_3d: Vector3D, + ac_coords_3d: Vector3d, + pc_coords_3d: Vector3d, orig_fsaverage_vox2vox: AffineMatrix4x4, -) -> tuple[AffineMatrix4x4, Vector3D, Vector3D, Vector3D, Vector3D]: +) -> tuple[AffineMatrix4x4, Vector3d, Vector3d, Vector3d, Vector3d]: """Get transformations to map image to standard space. 
Parameters @@ -161,10 +176,10 @@ def calc_mapping_to_standard_space( # Copy translation part to y,z axes (usually no translation) nod_correct_3d[1:3, 3] = nod_correct_2d[:2, 2] - ac_coords_after_nodding: Vector3D = apply_transform_to_pt( + ac_coords_after_nodding: Vector3d = apply_transform_to_pt( ac_coords_3d, nod_correct_3d, inv=False, ) - pc_coords_after_nodding: Vector3D = apply_transform_to_pt( + pc_coords_after_nodding: Vector3d = apply_transform_to_pt( pc_coords_3d, nod_correct_3d, inv=False, ) @@ -172,10 +187,10 @@ def calc_mapping_to_standard_space( ac_to_center_translation[:3, 3] = image_center - ac_coords_after_nodding # correct nodding - ac_coords_standardized: Vector3D = apply_transform_to_pt( + ac_coords_standardized: Vector3d = apply_transform_to_pt( ac_coords_after_nodding, ac_to_center_translation, inv=False, ) - pc_coords_standardized: Vector3D = apply_transform_to_pt( + pc_coords_standardized: Vector3d = apply_transform_to_pt( pc_coords_after_nodding, ac_to_center_translation, inv=False, ) @@ -186,10 +201,10 @@ def calc_mapping_to_standard_space( ) # calculate ac & pc in space of mri input image - ac_coords_orig: Vector3D = apply_transform_to_pt( + ac_coords_orig: Vector3d = apply_transform_to_pt( ac_coords_standardized, standardized_to_orig_vox2vox, inv=False, ) - pc_coords_orig: Vector3D = apply_transform_to_pt( + pc_coords_orig: Vector3d = apply_transform_to_pt( pc_coords_standardized, standardized_to_orig_vox2vox, inv=False, ) #FIXME: incorrect docstring @@ -197,7 +212,7 @@ def calc_mapping_to_standard_space( def apply_transform_to_volume( - orig_image: nib.analyze.SpatialImage, + orig_image: nibabelImage, vox2vox: AffineMatrix4x4, affine: AffineMatrix4x4, header: nib.freesurfer.mghformat.MGHHeader | None = None, @@ -209,13 +224,13 @@ def apply_transform_to_volume( Parameters ---------- - orig_image : nibabel.analyze.SpatialImage + orig_image : nibabelImage Input volume. vox2vox : np.ndarray Transformation matrix to apply to the data, this is from input-to-output space. affine : AffineMatrix4x4, optional The vox2ras matrix of the output image, only relevant if output_path is given. - header : nibabel.freesurfer.mghformat.MGHHeader, optional + header : nibabelHeader, optional Header for the output image, only relevant if output_path is given, if None will default to orig_image header. output_path : str or Path, optional If output_path is provided, saves the result under this path. @@ -248,7 +263,7 @@ def apply_transform_to_volume( return resampled -def make_affine(simpleITKImage: sitk.Image) -> npt.NDArray[float]: +def make_affine(simpleITKImage: sitk.Image) -> AffineMatrix4x4: """Create an affine transformation matrix from a SimpleITK image. Parameters @@ -282,33 +297,33 @@ def make_affine(simpleITKImage: sitk.Image) -> npt.NDArray[float]: def map_softlabels_to_orig( - cc_fn_softlabels: npt.NDArray[float], - orig_fsaverage_vox2vox: npt.NDArray[float], - orig: nib.analyze.SpatialImage, + cc_fn_softlabels: Image3d, + orig_fsaverage_vox2vox: AffineMatrix4x4, + orig: nibabelImage, orig_space_segmentation_path: str | Path | None = None, fsaverage_middle: int = 128, - cc_subseg_midslice: npt.NDArray[int] | None = None -) -> npt.NDArray[int]: + cc_subseg_midslice: Image2d | None = None +) -> np.ndarray[Shape3d, np.dtype[int]]: """Map soft labels back to original image space and apply post-processing. Parameters ---------- cc_fn_softlabels : np.ndarray Soft label predictions. 
- orig_fsaverage_vox2vox : np.ndarray + orig_fsaverage_vox2vox : AffineMatrix4x4 Original to fsaverage space transformation. - orig : nibabel.analyze.SpatialImage + orig : nibabelImage Original image. orig_space_segmentation_path : str or Path, optional Path to save segmentation in original space. fsaverage_middle : int, default=128 Middle slice index in fsaverage space. - cc_subseg_midslice : npt.NDArray[int], optional + cc_subseg_midslice : np.ndarray, optional Mask for subdividing regions. Returns ------- - npt.NDArray[int] + np.ndarray Final segmentation in original image space. Notes @@ -324,7 +339,7 @@ def map_softlabels_to_orig( slab2fsaverage_vox2vox[0, 3] = -(fsaverage_middle - slices_to_analyze // 2) slab2orig_vox2vox = orig_fsaverage_vox2vox @ slab2fsaverage_vox2vox - def _map_softlabel_to_orig(i: int, data: np.ndarray) -> np.ndarray: + def _map_softlabel_to_orig(i: int, data: Image3d) -> Image3d: return affine_transform(data, slab2orig_vox2vox, output_shape=orig.shape, order=1, cval=float(i == 0)) _softlabels = np.moveaxis(cc_fn_softlabels, -1, 0) @@ -358,62 +373,3 @@ def _map_softlabel_to_orig(i: int, data: np.ndarray) -> np.ndarray: ) return seg_orig_space - - -def interpolate_midplane( - orig: nib.Nifti1Image, - orig_fsaverage_vox2vox: np.ndarray, - slices_to_analyze: int) -> np.ndarray: - """Interpolates image data at the midplane using a grid of points. - - Parameters - ---------- - orig : nib.Nifti1Image - Original image. - orig_fsaverage_vox2vox : np.ndarray - Original to fsaverage space transformation matrix. - slices_to_analyze : int - Number of slices to analyze around midplane. - - Returns - ------- - np.ndarray - Interpolated image data at midplane. - """ - - # slice_thickness = 9+slices_to_analyze-1 - # make grid of 9 slices in the fsaverage middle - # (cube from 123.5,0.5,0.5 to 132.5,255.5,255.5 (incudling end points, 1mm spacing)) - x_coords = np.linspace( - 124 - slices_to_analyze // 2, - 132 + slices_to_analyze // 2, - 9 + (slices_to_analyze - 1), - endpoint=True, - ) # 9 points from 123.5 to 132.5 - y_coords = np.linspace( - 0, orig.shape[1] - 1, orig.shape[1], endpoint=True - ) # 255 points from 0.5 to 255.5 - z_coords = np.linspace( - 0, orig.shape[2] - 1, orig.shape[2], endpoint=True - ) # 255 points from 0.5 to 255.5 - X, Y, Z = np.meshgrid(x_coords, y_coords, z_coords, indexing="ij") - - # Stack coordinates and add homogeneous coordinate - grid_fsaverage = np.stack([X.ravel(), Y.ravel(), Z.ravel(), np.ones(X.size)]) - - # move grid to orig space by applying transform - grid_orig = np.linalg.inv(orig_fsaverage_vox2vox) @ grid_fsaverage - - # interpolate grid on orig image - from scipy.ndimage import map_coordinates - - transformed = map_coordinates( - np.asarray(orig.dataobj), - grid_orig[0:3, :], # use only x,y,z coordinates (drop homogeneous coordinate) - order=2, - mode="constant", - cval=0, - prefilter=True, - ).reshape(len(x_coords), len(y_coords), len(z_coords)) - - return transformed diff --git a/CorpusCallosum/utils/types.py b/CorpusCallosum/utils/types.py new file mode 100644 index 00000000..52b45f9c --- /dev/null +++ b/CorpusCallosum/utils/types.py @@ -0,0 +1,74 @@ +from typing import Literal, TypedDict + +from numpy import dtype, ndarray + +from FastSurferCNN.utils import ScalarType + +__all__ = [ + "CCMeasuresDict", + "ContourList", + "ContourThickness", + "Points2dType", + "Points3dType", + "Polygon2dType", + "Polygon3dType", + "SliceSelection", + "SubdivisionMethod", +] + +Polygon2dType = ndarray[tuple[Literal[2], int], dtype[ScalarType]] 
+Polygon3dType = ndarray[tuple[Literal[3], int], dtype[ScalarType]] +Points2dType = ndarray[tuple[int, Literal[2]], dtype[ScalarType]] +Points3dType = ndarray[tuple[int, Literal[3]], dtype[ScalarType]] +ContourList = list[Polygon2dType] +ContourThickness = ndarray[tuple[Literal[3], int], dtype[ScalarType]] +SliceSelection = Literal["middle", "all"] | int +SubdivisionMethod = Literal["shape", "vertical", "angular", "eigenvector"] + +class CCMeasuresDict(TypedDict): + """TypedDict for corpus callosum measures. + + Attributes + ---------- + cc_index : float + Corpus callosum shape index. + circularity : float + Shape circularity measure. + areas : np.ndarray + Areas of subdivided regions. + midline_length : float + Length along the midline. + thickness : float + Array of thickness measurements. + curvature : float + Array of curvature measurements. + thickness_profile : np.ndarray of type float + Thickness measurements along the contour. + total_area : float + Total area of the CC. + total_perimeter : float + Total perimeter length. + split_contours : list of np.ndarray + Subdivided contour segments in AS-slice coordinates. + midline_equidistant : np.ndarray + Equidistant points along midline in AS-slice coordinates. + levelpaths : list of np.ndarray + Paths for thickness measurements in AS-slice coordinates. + slice_index : int + Index of the processed slice. + """ + cc_index: float + circularity: float + areas: ndarray + midline_length: float + thickness: float + curvature: float + thickness_profile: ndarray[tuple[int], dtype[float]] + total_area: float + total_perimeter: float + total_area: float + total_perimeter: float + split_contours: ContourList + midline_equidistant: ndarray + levelpaths: list[ndarray] + slice_index: int From 005917cd45a4d5ddeaa3847e3149b51fc550a4e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Wed, 10 Dec 2025 10:51:27 +0100 Subject: [PATCH 49/68] - Fix ruff and documentation errors - Add types with documentation - Move contour documentation into contour module --- CorpusCallosum/cc_visualization.py | 5 +- CorpusCallosum/fastsurfer_cc.py | 4 +- CorpusCallosum/shape/__init__.py | 15 ++++++ CorpusCallosum/shape/contour.py | 60 ++++++++++++++++------ CorpusCallosum/shape/postprocessing.py | 56 +------------------- CorpusCallosum/shape/subsegment_contour.py | 5 +- doc/api/CorpusCallosum.shape.rst | 1 + doc/api/CorpusCallosum.utils.rst | 1 + doc/scripts/contour.rst | 36 ------------- 9 files changed, 69 insertions(+), 114 deletions(-) delete mode 100644 doc/scripts/contour.rst diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index 3ce992fa..ac205787 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -10,8 +10,7 @@ from CorpusCallosum.data.read_write import load_fsaverage_data from CorpusCallosum.shape.contour import CCContour from CorpusCallosum.shape.mesh import create_CC_mesh_from_contours -from FastSurferCNN.utils import logging -from FastSurferCNN.utils.logging import get_logger +from FastSurferCNN.utils.logging import get_logger, setup_logging logger = get_logger(__name__) @@ -236,7 +235,7 @@ def main( options = options_parse() # Set up logging if verbose mode is enabled - logging.setup_logging(None, options.verbose) # Log to stdout only + setup_logging(None, options.verbose) # Log to stdout only sys.exit(main( template_dir=options.template_dir, diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index ac343ad9..c52c66f4 100644 --- 
a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -47,9 +47,6 @@ from CorpusCallosum.segmentation import inference as segmentation_inference from CorpusCallosum.segmentation import segmentation_postprocessing from CorpusCallosum.shape.postprocessing import ( - CCMeasuresDict, - SliceSelection, - SubdivisionMethod, check_area_changes, create_sag_slice_vox2vox, make_subdivision_mask, @@ -61,6 +58,7 @@ calc_mapping_to_standard_space, map_softlabels_to_orig, ) +from CorpusCallosum.utils.types import CCMeasuresDict, SliceSelection, SubdivisionMethod from FastSurferCNN.data_loader.conform import conform, is_conform from FastSurferCNN.segstats import HelpFormatter from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Mask3d, Shape3d, Vector2d, logging, nibabelImage diff --git a/CorpusCallosum/shape/__init__.py b/CorpusCallosum/shape/__init__.py index e69de29b..4950a242 100644 --- a/CorpusCallosum/shape/__init__.py +++ b/CorpusCallosum/shape/__init__.py @@ -0,0 +1,15 @@ + +from CorpusCallosum.shape import endpoint_heuristic, mesh, metrics, postprocessing, subsegment_contour, thickness +from CorpusCallosum.shape.contour import CCContour +from CorpusCallosum.shape.mesh import CCMesh + +__all__ = [ + "CCContour", + "CCMesh", + "endpoint_heuristic", + "mesh", + "metrics", + "postprocessing", + "subsegment_contour", + "thickness", +] diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py index c61a474a..a0a18a25 100644 --- a/CorpusCallosum/shape/contour.py +++ b/CorpusCallosum/shape/contour.py @@ -12,6 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +This module provides the ``CCContour`` class for reading, writing, and +manipulating 2D corpus callosum contours together with per-vertex thickness +values. Typical template outputs (from ``fastsurfer_cc.py --save_template``) +emit one set per slice: + +- ``contour_.txt``: CSV with header ``New contour, anterior_endpoint_idx=, posterior_endpoint_idx=
`` followed + by ``x,y`` rows. +- ``thickness_values_.txt``: CSV with header ``thickness`` and one value per contour vertex. +- ``thickness_measurement_points_.txt``: CSV with header ``vertex_idx`` listing the vertices where thickness was + measured. +""" + import re from pathlib import Path from typing import Literal @@ -39,8 +52,23 @@ class CCContour: ---------- contour : np.ndarray Array of shape (N, 2) containing 2D contour points. + thickness_values : np.ndarray + Array of shape (N,) for thickness measurements for each contour point. endpoint_idxs : tuple[int, int] Tuple containing start and end indices for the contour. + + Examples + -------- + >>> from CorpusCallosum.shape.contour import CCContour + >>> + >>> contour = CCContour(contour_points, thickness_values, + >>> endpoint_idxs=(anterior_idx, posterior_idx), + >>> resolution=1.0) + >>> contour.fill_thickness_values() # interpolate missing values + >>> contour.smooth_contour(window_size=5) + >>> contour.save_contour("contour_0.txt") + >>> contour.save_thickness_values("thickness_values_0.txt") + >>> contour.save_thickness_measurement_points("thickness_measurement_points_0.txt") """ def __init__( @@ -58,8 +86,10 @@ def __init__( Array of shape (N, 2) containing 2D contour points. thickness_values : np.ndarray Array of thickness measurements for each contour point. - endpoint_idxs : tuple[int, int] + endpoint_idxs : tuple[int, int], optional Tuple containing start and end indices for the contour. + resolution : float, default=1.0 + The left-right spacing. """ self.contour = contour if self.contour.shape[1] != 2: @@ -68,7 +98,7 @@ def __init__( if self.contour.shape[0] != len(thickness_values): raise ValueError( f"Number of contour points ({self.contour.shape[0]}) does not match number of thickness values " - f"({len(thickness_values)})" + f"({len(thickness_values)})", ) # write vertex indices where thickness values are not nan self.original_thickness_vertices = np.where(~np.isnan(thickness_values))[0] @@ -84,8 +114,6 @@ def smooth_contour(self, window_size: int = 5) -> None: Parameters ---------- - contour_idx : int - Index of the contour to smooth. window_size : int, default=5 Size of the smoothing window. @@ -108,11 +136,6 @@ def copy(self) -> "CCContour": def get_contour_edge_lengths(self) -> np.ndarray: """Get the lengths of the edges of a contour. - Parameters - ---------- - contour_idx : int - Index of the contour to get the edge lengths for. - Returns ------- np.ndarray @@ -161,14 +184,18 @@ def set_thickness_values(self, thickness_values: np.ndarray, use_measurement_poi self.thickness_values = np.full(len(self.contour), np.nan) self.thickness_values[self.original_thickness_vertices] = thickness_values else: - raise ValueError("Number of thickness values " - f"does not match number of measurement points {len(self.original_thickness_vertices)}.") + raise ValueError( + "Number of thickness values does not match number of measurement points " + f"{len(self.original_thickness_vertices)}.", + ) else: - assert len(thickness_values) == len(self.contour), "Number of thickness values does not match number of " \ - f"points in the contour {len(self.contour)}." + if len(thickness_values) != len(self.contour): + raise ValueError( + f"The number of thickness values does not match number of points in the contour " + f"{len(self.contour)}.", + ) self.thickness_values = thickness_values - def fill_thickness_values(self) -> None: """Interpolate missing thickness values using weighted averaging. 
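`fill_thickness_values` interpolates the gaps between the sparse per-vertex thickness measurements stored in `original_thickness_vertices`. The class uses a weighted-averaging scheme; a shorter sketch of the same idea with periodic linear interpolation (`np.interp` with `period=n`, so values wrap around the closed contour):

```python
import numpy as np

def fill_thickness_sketch(thickness: np.ndarray) -> np.ndarray:
    """Fill NaN entries by linear interpolation along a closed contour."""
    n = len(thickness)
    idx = np.arange(n)
    known = ~np.isnan(thickness)
    if not known.any():
        raise ValueError("no thickness measurements to interpolate from")
    return np.interp(idx, idx[known], thickness[known], period=n)

vals = np.array([np.nan, 2.0, np.nan, np.nan, 5.0, np.nan])
print(fill_thickness_sketch(vals))  # NaNs replaced by wrapped linear blends
```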
@@ -368,7 +395,6 @@ def plot_contour_colorfill( # add third dimension to path path = np.column_stack([path, np.zeros(len(path))]) - if len(path) == 1: all_level_points_x.append(path[0][0]) all_level_points_y.append(path[0][1]) @@ -662,5 +688,7 @@ def load_thickness_values( f"{len(self.original_thickness_vertices)} does not match the number of set thickness values " f"{np.sum(~np.isnan(values))}." ) + else: + raise ValueError(f"Number of thickness values in {input_path} does not match the vertices of the path!") - self.thickness_values = new_values \ No newline at end of file + self.thickness_values = new_values diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index 0f7a7ed5..348f515e 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -15,7 +15,7 @@ from copy import copy from functools import partial from pathlib import Path -from typing import Literal, TypedDict, get_args +from typing import get_args import numpy as np @@ -34,15 +34,12 @@ transform_to_acpc_standard, ) from CorpusCallosum.shape.thickness import cc_thickness, convert_to_ras -from CorpusCallosum.utils.types import ContourThickness, Points2dType +from CorpusCallosum.utils.types import CCMeasuresDict, ContourThickness, Points2dType, SliceSelection, SubdivisionMethod from CorpusCallosum.utils.visualization import plot_contours from FastSurferCNN.utils import AffineMatrix4x4, Image3d, ScalarType, Shape2d, Shape3d, Vector2d from FastSurferCNN.utils.common import SubjectDirectory, suppress_stdout, update_docstring from FastSurferCNN.utils.parallel import process_executor, thread_executor -SubdivisionMethod = Literal["shape", "vertical", "angular", "eigenvector"] -SliceSelection = Literal["middle", "all"] | int - logger = logging.get_logger(__name__) # assert LIA orientation @@ -52,55 +49,6 @@ LIA_ORIENTATION[2,1] = -1 -class CCMeasuresDict(TypedDict): - """TypedDict for corpus callosum measures. - - Attributes - ---------- - cc_index : float - Corpus callosum shape index. - circularity : float - Shape circularity measure. - areas : np.ndarray - Areas of subdivided regions. - midline_length : float - Length along the midline. - thickness : float - Array of thickness measurements. - curvature : float - Array of curvature measurements. - thickness_profile : np.ndarray of type float - Thickness measurements along the contour. - total_area : float - Total area of the CC. - total_perimeter : float - Total perimeter length. - split_contours : list of np.ndarray - Subdivided contour segments in AS-slice coordinates. - midline_equidistant : np.ndarray - Equidistant points along midline in AS-slice coordinates. - levelpaths : list of np.ndarray - Paths for thickness measurements in AS-slice coordinates. - slice_index : int - Index of the processed slice. - """ - cc_index: float - circularity: float - areas: np.ndarray - midline_length: float - thickness: float - curvature: float - thickness_profile: np.ndarray[tuple[int], np.dtype[float]] - total_area: float - total_perimeter: float - total_area: float - total_perimeter: float - split_contours: ContourList - midline_equidistant: np.ndarray - levelpaths: list[np.ndarray] - slice_index: int - - def create_sag_slice_vox2vox(slice_idx: int, fsaverage_middle: float) -> AffineMatrix4x4: """Create slice-specific slice to full affine transformation matrix. 
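With `SliceSelection = Literal["middle", "all"] | int` now shared from `CorpusCallosum.utils.types`, a CLI `type` callable has to accept the two keywords or fall back to an integer slice index; a later patch in this series mentions fixing the `--slice_selection` type function. A hedged sketch of what such a parser could look like (the actual implementation may differ):

```python
import argparse
from typing import Literal

SliceSelection = Literal["middle", "all"] | int  # mirrors CorpusCallosum.utils.types

def parse_slice_selection(value: str) -> SliceSelection:
    """argparse `type` callable: keyword or integer slice index."""
    if value in ("middle", "all"):
        return value
    try:
        return int(value)
    except ValueError as e:
        raise argparse.ArgumentTypeError(
            f"expected 'middle', 'all' or an integer, got {value!r}"
        ) from e

parser = argparse.ArgumentParser()
parser.add_argument("--slice_selection", type=parse_slice_selection, default="middle")
```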
diff --git a/CorpusCallosum/shape/subsegment_contour.py b/CorpusCallosum/shape/subsegment_contour.py index cd62e139..a98037c8 100644 --- a/CorpusCallosum/shape/subsegment_contour.py +++ b/CorpusCallosum/shape/subsegment_contour.py @@ -873,11 +873,12 @@ def preprocess_cc(cc_label_nib: nibabelImage, paths_csv: "pd.DataFrame", subj_id 2D coordinates of posterior commissure. """ + #FIXME: this function is not used anywhere _cc_mask: Mask3d = np.asarray(cc_label_nib.dataobj) == 192 cc_mask: Mask2d = _cc_mask[_cc_mask.shape[0] // 2] - posterior_commisure_center = paths_csv.loc[subj_id, "PC_center_r":"PC_center_s"].to_numpy().astype(float) - anterior_commisure_center = paths_csv.loc[subj_id, "AC_center_r":"AC_center_s"].to_numpy().astype(float) + posterior_commisure_center = paths_csv.loc[subj_id, "PC_center_r": "PC_center_s"].to_numpy().astype(float) + anterior_commisure_center = paths_csv.loc[subj_id, "AC_center_r": "AC_center_s"].to_numpy().astype(float) # adjust LR from label coordinates to orig_up coordinates posterior_commisure_center[0] = 128 diff --git a/doc/api/CorpusCallosum.shape.rst b/doc/api/CorpusCallosum.shape.rst index cd89aedc..f4c059e3 100644 --- a/doc/api/CorpusCallosum.shape.rst +++ b/doc/api/CorpusCallosum.shape.rst @@ -12,3 +12,4 @@ CorpusCallosum.shape thickness subsegment_contour endpoint_heuristic + contour diff --git a/doc/api/CorpusCallosum.utils.rst b/doc/api/CorpusCallosum.utils.rst index a6595d5b..33fe5e04 100644 --- a/doc/api/CorpusCallosum.utils.rst +++ b/doc/api/CorpusCallosum.utils.rst @@ -8,4 +8,5 @@ CorpusCallosum.utils checkpoint mapping_helpers + types visualization diff --git a/doc/scripts/contour.rst b/doc/scripts/contour.rst deleted file mode 100644 index faf75aec..00000000 --- a/doc/scripts/contour.rst +++ /dev/null @@ -1,36 +0,0 @@ -CorpusCallosum: contour.py -========================== - -This module provides the ``CCContour`` class for reading, writing, and -manipulating 2D corpus callosum contours together with per-vertex thickness -values. Typical template outputs (from ``fastsurfer_cc.py --save_template``) -emit one set per slice: - -- ``contour_.txt``: CSV with header ``New contour, anterior_endpoint_idx=, posterior_endpoint_idx=
`` followed by ``x,y`` rows. -- ``thickness_values_.txt``: CSV with header ``thickness`` and one value per contour vertex. -- ``thickness_measurement_points_.txt``: CSV with header ``vertex_idx`` listing the vertices where thickness was measured. - -Key usage patterns ------------------- - -.. code-block:: python - - from CorpusCallosum.shape.contour import CCContour - - contour = CCContour(contour_points, thickness_values, - endpoint_idxs=(anterior_idx, posterior_idx), - resolution=1.0) - contour.fill_thickness_values() # interpolate missing values - contour.smooth_contour(window_size=5) - contour.save_contour("contour_0.txt") - contour.save_thickness_values("thickness_values_0.txt") - contour.save_thickness_measurement_points("thickness_measurement_points_0.txt") - -Reference ---------- - -.. automodule:: CorpusCallosum.shape.contour - :members: CCContour - :undoc-members: - :show-inheritance: - From de1bedf0ea12506e2b7a0242d9ff029f7a1991ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Fri, 12 Dec 2025 17:29:11 +0100 Subject: [PATCH 50/68] Rename create_CC_mesh_from_contours to CCMesh.from_contours and make it a class method Remove vox2ras_tkr from fsaverage information Fix surface saving Fix --slice_selection type function Fix mapping from upright to orig space Change types of ndarrays from python to numpy Scalar types --- CorpusCallosum/cc_visualization.py | 25 +- CorpusCallosum/data/fsaverage_data.json | 28 +- CorpusCallosum/data/read_write.py | 7 +- CorpusCallosum/fastsurfer_cc.py | 52 ++- CorpusCallosum/localization/inference.py | 12 +- CorpusCallosum/segmentation/inference.py | 16 +- .../segmentation_postprocessing.py | 4 +- CorpusCallosum/shape/contour.py | 4 +- CorpusCallosum/shape/endpoint_heuristic.py | 8 +- CorpusCallosum/shape/mesh.py | 369 +++++++++--------- CorpusCallosum/shape/postprocessing.py | 50 ++- CorpusCallosum/shape/subsegment_contour.py | 6 +- CorpusCallosum/shape/thickness.py | 2 +- CorpusCallosum/transforms/segmentation.py | 2 +- CorpusCallosum/utils/mapping_helpers.py | 93 +++-- FastSurferCNN/utils/__init__.py | 1 - env/fastsurfer.yml | 2 +- pyproject.toml | 2 +- 18 files changed, 340 insertions(+), 343 deletions(-) diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index ac205787..c6911729 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -5,11 +5,10 @@ import numpy as np -from CorpusCallosum.data.constants import FSAVERAGE_DATA_PATH +from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE from CorpusCallosum.data.fsaverage_cc_template import load_fsaverage_cc_template -from CorpusCallosum.data.read_write import load_fsaverage_data from CorpusCallosum.shape.contour import CCContour -from CorpusCallosum.shape.mesh import create_CC_mesh_from_contours +from CorpusCallosum.shape.mesh import CCMesh from FastSurferCNN.utils.logging import get_logger, setup_logging logger = get_logger(__name__) @@ -93,9 +92,9 @@ def make_parser() -> argparse.ArgumentParser: default=0, help="Enable verbose (pass twice for debug-output).", ) - return parser + def options_parse() -> argparse.Namespace: """Parse command line arguments for the pipeline.""" parser = make_parser() @@ -103,12 +102,9 @@ def options_parse() -> argparse.Namespace: # Create output directory if it doesn't exist Path(args.output_dir).mkdir(parents=True, exist_ok=True) - return args - - def load_contours_from_template_dir( template_dir: Path, resolution: float, smoothing_window: int ) -> list[CCContour]: @@ -122,7 
+118,6 @@ def load_contours_from_template_dir( ) fsaverage_contour = None - contours: list[CCContour] = [] for thickness_file in thickness_files: try: @@ -173,17 +168,12 @@ def main( output_dir = Path(output_dir) color_range = tuple(color_range) if color_range is not None else None - _, _, vox2ras_tkr = load_fsaverage_data(FSAVERAGE_DATA_PATH) - contours = load_contours_from_template_dir( - Path(template_dir), resolution=resolution, smoothing_window=smoothing_window + Path(template_dir), resolution=resolution, smoothing_window=smoothing_window, ) # 2D visualization mid_contour = contours[len(contours) // 2] - - - # for now, we only support thickness visualization, this is preparing to plot also p-values and icc values mode = "thickness" @@ -191,7 +181,7 @@ def main( if mode == "thickness": raw_thickness_values = mid_contour.thickness_values[~np.isnan(mid_contour.thickness_values)] - # values are duplicated because we they have two measurement points per levelpath + # values are duplicated because they have two measurement points per levelpath raw_thickness_values = raw_thickness_values[len(raw_thickness_values) // 2:] mid_contour.plot_contour_colorfill( plot_values=raw_thickness_values, @@ -204,7 +194,7 @@ def main( return 0 # 3D visualization - cc_mesh = create_CC_mesh_from_contours(contours, smooth=0) + cc_mesh = CCMesh.from_contours(contours, smooth=0) plot_kwargs = dict( colormap=colormap, @@ -215,8 +205,7 @@ def main( cc_mesh.plot_mesh(**plot_kwargs) cc_mesh.plot_mesh(output_path=str(output_dir / "cc_mesh.html"), **plot_kwargs) - - cc_mesh.to_fs_coordinates(vox2ras_tkr=vox2ras_tkr) + cc_mesh = cc_mesh.to_fs_coordinates(lr_offset=FSAVERAGE_MIDDLE / resolution) logger.info(f"Writing vtk file to {output_dir / 'cc_mesh.vtk'}") cc_mesh.write_vtk(str(output_dir / "cc_mesh.vtk")) logger.info(f"Writing freesurfer surface file to {output_dir / 'cc_mesh.fssurf'}") diff --git a/CorpusCallosum/data/fsaverage_data.json b/CorpusCallosum/data/fsaverage_data.json index 0fdd17fb..42cf562b 100644 --- a/CorpusCallosum/data/fsaverage_data.json +++ b/CorpusCallosum/data/fsaverage_data.json @@ -58,31 +58,5 @@ -128.0, 128.0 ] - }, - "vox2ras_tkr": [ - [ - -1.0, - 0.0, - 0.0, - 128.0 - ], - [ - 0.0, - 0.0, - 1.0, - -128.0 - ], - [ - 0.0, - -1.0, - 0.0, - 128.0 - ], - [ - 0.0, - 0.0, - 0.0, - 1.0 - ] - ] + } } \ No newline at end of file diff --git a/CorpusCallosum/data/read_write.py b/CorpusCallosum/data/read_write.py index e11b8632..fcc38c17 100644 --- a/CorpusCallosum/data/read_write.py +++ b/CorpusCallosum/data/read_write.py @@ -154,7 +154,7 @@ def load_fsaverage_affine(affine_path: str | Path) -> npt.NDArray[float]: return affine_matrix -def load_fsaverage_data(data_path: str | Path) -> tuple[AffineMatrix4x4, FSAverageHeader, AffineMatrix4x4]: +def load_fsaverage_data(data_path: str | Path) -> tuple[AffineMatrix4x4, FSAverageHeader]: """Load fsaverage affine matrix and header fields from static JSON file. Parameters @@ -176,8 +176,6 @@ def load_fsaverage_data(data_path: str | Path) -> tuple[AffineMatrix4x4, FSAvera 3x3 direction cosines matrix. - Pxyz_c : np.ndarray RAS center coordinates [x,y,z]. - vox2ras_tkr : AffineMatrix4x4 - Voxel to RAS tkr-space transformation matrix. 
Raises ------ @@ -208,7 +206,6 @@ def load_fsaverage_data(data_path: str | Path) -> tuple[AffineMatrix4x4, FSAvera # Convert lists back to numpy arrays affine_matrix = np.array(data["affine"]) - vox2ras_tkr = np.array(data["vox2ras_tkr"]) header_data = FSAverageHeader( dims=data["header"]["dims"], delta=data["header"]["delta"], @@ -220,4 +217,4 @@ def load_fsaverage_data(data_path: str | Path) -> tuple[AffineMatrix4x4, FSAvera if affine_matrix.shape != (4, 4): raise ValueError(f"Expected 4x4 affine matrix, got shape {affine_matrix.shape}") - return affine_matrix, header_data, vox2ras_tkr + return affine_matrix, header_data diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index c52c66f4..a927bf48 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -61,11 +61,11 @@ from CorpusCallosum.utils.types import CCMeasuresDict, SliceSelection, SubdivisionMethod from FastSurferCNN.data_loader.conform import conform, is_conform from FastSurferCNN.segstats import HelpFormatter -from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Mask3d, Shape3d, Vector2d, logging, nibabelImage +from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Image4d, Mask3d, Shape3d, Vector2d, logging, nibabelImage from FastSurferCNN.utils.arg_types import path_or_none from FastSurferCNN.utils.common import SubjectDirectory, find_device from FastSurferCNN.utils.lta import write_lta -from FastSurferCNN.utils.parallel import shutdown_executors, thread_executor +from FastSurferCNN.utils.parallel import get_num_threads, serial_executor, shutdown_executors, thread_executor from FastSurferCNN.utils.parser_defaults import modify_argument from recon_surf.align_points import find_rigid @@ -163,7 +163,7 @@ def _set_help_sid(action): "cost of precision.", ) def _slice_selection(a: str) -> SliceSelection: - if b := a.lower() in ("middle", "all"): + if (b := a.lower()) in ("middle", "all"): return b return int(a) parser.add_argument( @@ -366,7 +366,7 @@ def options_parse() -> argparse.Namespace: def register_centroids_to_fsavg(aseg_nib: nibabelImage) \ - -> tuple[AffineMatrix4x4, AffineMatrix4x4, AffineMatrix4x4, FSAverageHeader, AffineMatrix4x4]: + -> tuple[AffineMatrix4x4, AffineMatrix4x4, AffineMatrix4x4, FSAverageHeader]: """Perform centroid-based registration between subject and fsaverage space. Computes a rigid transformation between the subject's segmentation and fsaverage space @@ -387,8 +387,6 @@ def register_centroids_to_fsavg(aseg_nib: nibabelImage) \ High-resolution fsaverage affine matrix. fsaverage_header : FSAverageHeader FSAverage header fields for LTA writing. - fsaverage_vox2ras_tkr : AffineMatrix4x4 - Voxel to RAS tkr-space transformation matrix. 
Notes ----- @@ -416,7 +414,7 @@ def register_centroids_to_fsavg(aseg_nib: nibabelImage) \ aseg_zooms = list(nib.as_closest_canonical(aseg_nib).header.get_zooms()[:3]) resolution_trans: AffineMatrix4x4 = np.diagflat([aseg_zooms[0], aseg_zooms[2], aseg_zooms[1], 1]).astype(float) - fsaverage_vox2ras, fsavg_header, vox2ras_tkr = fsaverage_data_future.result() + fsaverage_vox2ras, fsavg_header = fsaverage_data_future.result() fsavg_header["delta"] = np.asarray([aseg_zooms[0], aseg_zooms[2], aseg_zooms[1]]) # vox sizes in lia # fsavg_hires_vox2ras translation should be 128 always (independent of resolution) fsavg_hires_vox2ras: AffineMatrix4x4 = np.concatenate( @@ -427,7 +425,7 @@ def register_centroids_to_fsavg(aseg_nib: nibabelImage) \ aseg2fsavg_vox2vox: AffineMatrix4x4 = np.linalg.inv(fsavg_hires_vox2ras) @ aseg2fsaverage_ras2ras @ aseg_nib.affine logger.info("Centroid registration successful!") - return aseg2fsavg_vox2vox, aseg2fsaverage_ras2ras, fsavg_hires_vox2ras, fsavg_header, vox2ras_tkr + return aseg2fsavg_vox2vox, aseg2fsaverage_ras2ras, fsavg_hires_vox2ras, fsavg_header def localize_ac_pc( @@ -494,7 +492,7 @@ def segment_cc( pc_coords: Vector2d, aseg_nib: nibabelImage, model_segmentation: "torch.nn.Module", -) -> tuple[Mask3d, Image3d]: +) -> tuple[Mask3d, Image4d]: """Segment the corpus callosum using a trained model. Performs corpus callosum segmentation on mid-sagittal slices using a trained model, with AC-PC points as anatomical @@ -518,7 +516,7 @@ def segment_cc( cc_seg_labels : np.ndarray Binary cc_seg_labels of the corpus callosum. cc_softlabels : np.ndarray - Soft cc_seg_labels probabilities. + Soft cc_seg_labels probabilities of shape (H, W, D, C=3). """ pre_clean_segmentation, inputs, cc_softlabels = segmentation_inference.run_inference_on_slice( model_segmentation, @@ -732,9 +730,7 @@ def main( sys.exit(1) logger.info("Performing centroid registration to fsaverage space") - orig2fsavg_vox2vox, orig2fsavg_ras2ras, fsavg_vox2ras, fsavg_header, fsavg_vox2ras_tkr = ( - register_centroids_to_fsavg(aseg_img) - ) + orig2fsavg_vox2vox, orig2fsavg_ras2ras, fsavg_vox2ras, fsavg_header = register_centroids_to_fsavg(aseg_img) # start saving upright volume, this is the image in fsaverage space but not yet oriented via AC-PC if sd.has_attribute("upright_volume"): @@ -754,11 +750,11 @@ def main( affine_x_offset = partial(create_sag_slice_vox2vox, fsaverage_middle=FSAVERAGE_MIDDLE / vox_size[0]) fsavg2midslab_in_vox2vox: AffineMatrix4x4 = affine_x_offset(slices_to_analyze // 2) # first, midslice->fsaverage in vox2vox, then vox2ras in fsaverage space - midslab_vox2ras: AffineMatrix4x4 = fsavg_vox2ras @ np.linalg.inv(fsavg2midslab_in_vox2vox) + fsaverage_midslab_vox2ras: AffineMatrix4x4 = fsavg_vox2ras @ np.linalg.inv(fsavg2midslab_in_vox2vox) # calculate vox2vox for input resampling volumes - def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4: - fsavg2midslab = affine_x_offset(slices_to_analyze // 2 + extra_slices // 2) + def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: + fsavg2midslab = affine_x_offset(slices_to_analyze // 2 + additional_context // 2) # first, orig->fsaverage in vox2vox, then fsaverage->midslab in vox2vox return fsavg2midslab @ orig2fsavg_vox2vox @@ -769,16 +765,16 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4: ac_coords, pc_coords = localize_ac_pc( np.asarray(orig.dataobj), aseg_img, - _orig2midslab_vox2vox(extra_slices=2), + _orig2midslab_vox2vox(additional_context=2), _model_localization.result(), 
target_shape, ) logger.info("Starting corpus callosum segmentation") - # "+ 8" in x-direction for context slices - target_shape: Shape3d = (slices_to_analyze + 8, fsavg_header["dims"][1], fsavg_header["dims"][2]) + num_context = 8 # 8 extra in x-direction for context slices + target_shape: Shape3d = (slices_to_analyze + num_context, fsavg_header["dims"][1], fsavg_header["dims"][2]) midslices: Image3d = affine_transform( np.asarray(orig.dataobj), - np.linalg.inv(_orig2midslab_vox2vox(extra_slices=8)), # inverse is required for affine_transform + np.linalg.inv(_orig2midslab_vox2vox(additional_context=num_context)), # inverse is required for affine_transform output_shape=target_shape, order=2, # @ClePol unclear, why this is not order=3 mode="constant", @@ -799,7 +795,7 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4: logger.info(f"Saving {name} softlabels to {sd.filename_by_attribute(f'cc_softlabels_{attr}')}") io_futures.append(thread_executor().submit( nib.save, - nib.MGHImage(cc_fn_softlabels[..., i], midslab_vox2ras, orig.header), + nib.MGHImage(cc_fn_softlabels[..., i], fsaverage_midslab_vox2ras, orig.header), sd.filename_by_attribute(f"cc_softlabels_{attr}"), )) @@ -819,7 +815,6 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4: subdivision_method=subdivision_method, contour_smoothing=contour_smoothing, vox_size=vox_size, - vox2ras_tkr=fsavg_vox2ras_tkr, subject_dir=sd, ) io_futures.extend(slice_io_futures) @@ -837,7 +832,7 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4: cc_subseg_midslice = make_subdivision_mask( (cc_fn_seg_labels.shape[1], cc_fn_seg_labels.shape[2]), middle_slice_result["split_contours"], - vox_size[1:], + vox_size[1:3], ) else: logger.warning("Too many subsegments for lookup table, skipping sub-division of output segmentation.") @@ -847,19 +842,21 @@ def _orig2midslab_vox2vox(extra_slices: int) -> AffineMatrix4x4: if sd.has_attribute("cc_segmentation"): io_futures.append(thread_executor().submit( nib.save, - nib.MGHImage(cc_fn_seg_labels, midslab_vox2ras, orig.header), + nib.MGHImage(cc_fn_seg_labels, fsaverage_midslab_vox2ras, orig.header), sd.filename_by_attribute("cc_segmentation"), )) # map soft labels to original space (in parallel because this takes a while, and we only do it to save the labels) if sd.has_attribute("cc_orig_segfile"): - io_futures.append(thread_executor().submit( + # if num_threads is not large enough (>1), this might be blocking ; serial_executor runs the function in submit + executor = thread_executor() if get_num_threads() > 2 else serial_executor() + io_futures.append(executor.submit( map_softlabels_to_orig, cc_fn_softlabels=cc_fn_softlabels, - orig_fsaverage_vox2vox=orig2fsavg_vox2vox, orig=orig, orig_space_segmentation_path=sd.filename_by_attribute("cc_orig_segfile"), - fsaverage_middle=FSAVERAGE_MIDDLE, + orig2slab_vox2vox=_orig2midslab_vox2vox(), cc_subseg_midslice=cc_subseg_midslice, + orig2midslice_vox2vox=affine_x_offset(0) @ orig2fsavg_vox2vox, # orig2fsavg, then full2midslice )) METRICS = [ @@ -1004,6 +1001,7 @@ def save_cc_measures_json(cc_mid_measure_file: Path, metrics: dict[str, object]) conf_name=options.conf_name, aseg_name=options.aseg_name, subject_dir=options.subject_dir, + #FIXME: slice_selection is True/bool slice_selection=options.slice_selection, num_thickness_points=options.num_thickness_points, subdivisions=list(options.subdivisions), diff --git a/CorpusCallosum/localization/inference.py b/CorpusCallosum/localization/inference.py index 15e99f37..1837864a 100644 
--- a/CorpusCallosum/localization/inference.py +++ b/CorpusCallosum/localization/inference.py @@ -13,7 +13,7 @@ # limitations under the License. from pathlib import Path -from typing import Literal +from typing import Literal, cast import numpy as np import torch @@ -177,7 +177,7 @@ def predict( # Preprocess t_dict = preprocess_volume(image_volume, patch_center_3d, transform) - transformed_original = t_dict['image'] + transformed_original = cast(torch.Tensor, t_dict["image"]) inputs = transformed_original.to(device) inputs = inputs.transpose(0, 1) @@ -187,9 +187,11 @@ def predict( with torch.no_grad(): outputs = model(inputs) * torch.as_tensor([PATCH_SIZE + PATCH_SIZE], device=device) - t_crops = [(t_dict['crop_left'] + t_dict['crop_top']) * 2] - outs: np.ndarray[tuple[int, Literal[4]], np.dtype[float]] = outputs.cpu().numpy() + np.asarray(t_crops, dtype=float) - crop_offsets: tuple[int, int] = (t_dict["crop_left"][0], t_dict["crop_top"][0]) + crop_left, crop_top = cast(tuple[int, int], t_dict["crop_left"]), cast(tuple[int, int], t_dict["crop_top"]) + t_crops = [(crop_left + crop_top) * 2] + outs: np.ndarray[tuple[int, Literal[4]], np.dtype[np.float_]] + outs = outputs.cpu().numpy() + np.asarray(t_crops, dtype=float) + crop_offsets: tuple[int, int] = (crop_left[0], crop_top[0]) return outs[:, :2], outs[:, 2:], crop_offsets diff --git a/CorpusCallosum/segmentation/inference.py b/CorpusCallosum/segmentation/inference.py index 242177cd..9704b3b4 100644 --- a/CorpusCallosum/segmentation/inference.py +++ b/CorpusCallosum/segmentation/inference.py @@ -91,7 +91,7 @@ def run_inference( voxel_size: tuple[float, float], device: torch.device | None = None, transform: transforms.Transform | None = None -) -> tuple[np.ndarray[Shape4d, np.dtype[int]], Image4d, Image4d]: +) -> tuple[np.ndarray[Shape4d, np.dtype[np.int_]], Image4d, Image4d]: """Run inference on a single image slice. Parameters @@ -228,15 +228,17 @@ def _load(label_path: str | Path) -> int: return images, ac_centers, pc_centers, label_widths, labels, subj_ids @overload -def one_hot_to_label(one_hot: Image4d, label_ids: list[int] | None = None) -> np.ndarray[Shape3d, np.dtype[int]]: ... +def one_hot_to_label(one_hot: Image4d, label_ids: list[int] | None = None) \ + -> np.ndarray[Shape3d, np.dtype[np.int_]]: ... @overload -def one_hot_to_label(one_hot: Image3d, label_ids: list[int] | None = None) -> np.ndarray[Shape2d, np.dtype[int]]: ... +def one_hot_to_label(one_hot: Image3d, label_ids: list[int] | None = None) \ + -> np.ndarray[Shape2d, np.dtype[np.int_]]: ... def one_hot_to_label( - one_hot: np.ndarray[tuple[int, ...], np.dtype[bool]], + one_hot: np.ndarray[tuple[int, ...], np.dtype[np.bool_]], label_ids: list[int] | None = None, -) -> np.ndarray[tuple[int, ...], np.dtype[int]]: +) -> np.ndarray[tuple[int, ...], np.dtype[np.int_]]: """Convert one-hot encoded segmentation to label map. Converts a one-hot encoded segmentation array to discrete labels by taking @@ -257,7 +259,7 @@ def one_hot_to_label( """ if label_ids is None: from CorpusCallosum.data.constants import CC_LABEL, FORNIX_LABEL - label_ids = [0, FORNIX_LABEL, CC_LABEL] + label_ids = [0, CC_LABEL, FORNIX_LABEL] label = np.argmax(one_hot, axis=3) if label_ids is not None: @@ -272,7 +274,7 @@ def run_inference_on_slice( ac_center: Vector2d, pc_center: Vector2d, voxel_size: tuple[float, float], -) -> tuple[np.ndarray[Shape3d, np.dtype[int]], Image4d, Image4d]: +) -> tuple[np.ndarray[Shape3d, np.dtype[np.int_]], Image4d, Image4d]: """Run inference on a single slice. 
Parameters diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py index 68ecdd64..fb62ca94 100644 --- a/CorpusCallosum/segmentation/segmentation_postprocessing.py +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -500,9 +500,9 @@ def extract_largest_connected_component( def clean_cc_segmentation( - seg_arr: np.ndarray[Shape3d, np.dtype[int]], + seg_arr: np.ndarray[Shape3d, np.dtype[np.int_]], max_connection_distance: float = 3.0, -) -> tuple[np.ndarray[Shape3d, np.dtype[int]], Mask3d]: +) -> tuple[np.ndarray[Shape3d, np.dtype[np.int_]], Mask3d]: """Clean corpus callosum segmentation by removing non-connected components. Parameters diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py index a0a18a25..272b7a5f 100644 --- a/CorpusCallosum/shape/contour.py +++ b/CorpusCallosum/shape/contour.py @@ -73,8 +73,8 @@ class CCContour: def __init__( self, - contour: np.ndarray[tuple[Literal["N", 2]], np.dtype[float]], - thickness_values: np.ndarray[tuple[Literal["N"]], np.dtype[float]], + contour: np.ndarray[tuple[Literal["N", 2]], np.dtype[np.float_]], + thickness_values: np.ndarray[tuple[Literal["N"]], np.dtype[np.float_]], endpoint_idxs: tuple[int, int] | None = None, resolution: float = 1.0 ): diff --git a/CorpusCallosum/shape/endpoint_heuristic.py b/CorpusCallosum/shape/endpoint_heuristic.py index eb5f6095..2c89bade 100644 --- a/CorpusCallosum/shape/endpoint_heuristic.py +++ b/CorpusCallosum/shape/endpoint_heuristic.py @@ -19,7 +19,7 @@ import skimage.measure from scipy.ndimage import label -from FastSurferCNN.utils import Vector2d +from FastSurferCNN.utils import Mask2d, Vector2d def smooth_contour(x: np.ndarray, y: np.ndarray, window_size: int) -> tuple[np.ndarray, np.ndarray]: @@ -169,7 +169,7 @@ def extract_cc_contour(cc_mask: np.ndarray, contour_smoothing: int = 5) -> np.nd @overload def get_endpoints( - cc_mask: np.ndarray[tuple[int, int], np.dtype[bool]], + cc_mask: Mask2d, ac_2d: Vector2d, pc_2d: Vector2d, resolution: tuple[float, float], @@ -180,7 +180,7 @@ def get_endpoints( @overload def get_endpoints( - cc_mask: np.ndarray[tuple[int, int], np.dtype[bool]], + cc_mask: Mask2d, ac_2d: Vector2d, pc_2d: Vector2d, resolution: tuple[float, float], @@ -190,7 +190,7 @@ def get_endpoints( def get_endpoints( - cc_mask: np.ndarray[tuple[int, int], np.dtype[bool]], + cc_mask: Mask2d, ac_2d: Vector2d, pc_2d: Vector2d, resolution: tuple[float, float], diff --git a/CorpusCallosum/shape/mesh.py b/CorpusCallosum/shape/mesh.py index 62b36c1d..74af3c63 100644 --- a/CorpusCallosum/shape/mesh.py +++ b/CorpusCallosum/shape/mesh.py @@ -14,6 +14,7 @@ import tempfile from pathlib import Path +from typing import TypeVar import lapy import nibabel as nib @@ -23,9 +24,9 @@ from scipy.ndimage import gaussian_filter1d import FastSurferCNN.utils.logging as logging -from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE from CorpusCallosum.shape.contour import CCContour from CorpusCallosum.shape.thickness import make_mesh_from_contour +from FastSurferCNN.utils import nibabelImage from FastSurferCNN.utils.common import suppress_stdout try: @@ -176,124 +177,7 @@ def make_triangles_between_contours(contour1: np.ndarray, contour2: np.ndarray) return np.array(triangles) - -def create_CC_mesh_from_contours( - contours: list[CCContour], - lr_center: float = 0, - closed: bool = False, - smooth: int = 0, -) -> "CCMesh": - """Create a surface mesh by triangulating between consecutive contours. 
- - Parameters - ---------- - contours : list[CCContour] - List of CCContour objects to create mesh from. - lr_center : float, optional - Center position in the left-right axis, by default 0. - closed : bool, optional - Whether to create a closed mesh by adding caps, by default False. - smooth : int, optional - Number of smoothing iterations to apply, by default 0. - - Returns - ------- - CCMesh - The joined CCMesh object. - - Raises - ------ - Warning - If no valid contours are found. - - Notes - ----- - The function: - 1. Filters out None contours. - 2. Calculates z-coordinates for each slice. - 3. Creates triangles between adjacent contours. - 4. Optionally: - - Creates caps at both ends. - - Applies smoothing. - - Colors caps based on thickness values. - - """ - - # Check that all contours have the same resolution - resolution = contours[0].resolution - for idx, contour in enumerate(contours[1:], start=1): - if not np.isclose(contour.resolution, resolution): - raise ValueError( - f"All contours must have the same resolution. " - f"Expected {resolution}, but contour at index {idx} has {contour.resolution}." - ) - - - # Calculate z coordinates for each slice - z_coordinates = (np.arange(len(contours)) - len(contours) // 2) * contours[0].resolution + lr_center - - # Build vertices list with z-coordinates - vertices = [] - faces = [] - vertex_start_indices = [] # Track starting index for each contour - current_index = 0 - - for i, contour in enumerate(contours): - vertex_start_indices.append(current_index) - vertices.append(np.hstack([contour.contour, np.full((len(contour.contour), 1), z_coordinates[i])])) - - # Check if there's a next valid contour to connect to - if i + 1 < len(contours): - contour2 = contours[i + 1] - faces_between = make_triangles_between_contours(contour.contour, contour2.contour) - faces.append(faces_between + current_index) - - current_index += len(contour.contour) - - vertex_values = np.concatenate([contour.thickness_values for contour in contours]) - - - - if smooth > 0: - tmp_mesh = CCMesh(vertices, faces, vertex_values=vertex_values) - tmp_mesh.smooth_(smooth) - vertices = tmp_mesh.v - faces = tmp_mesh.t - vertex_values = tmp_mesh.mesh_vertex_colors - - if closed: - # Close the mesh by creating caps on both ends - # Left cap (first slice) - use counterclockwise orientation - left_side_points, left_side_trias = make_mesh_from_contour(vertices[: vertex_start_indices[1]][..., :2]) - left_side_points = np.hstack([left_side_points, np.full((len(left_side_points), 1), z_coordinates[0])]) - - # Right cap (last slice) - reverse points for proper orientation - right_side_points, right_side_trias = make_mesh_from_contour(vertices[vertex_start_indices[-1] :][..., :2]) - right_side_points = np.hstack([right_side_points, np.full((len(right_side_points), 1), z_coordinates[-1])]) - - color_sides = True - if color_sides: - left_side_points, left_side_trias, left_side_colors = _create_cap( - left_side_points, left_side_trias, contours[0] - ) - right_side_points, right_side_trias, right_side_colors = _create_cap( - right_side_points, right_side_trias, contours[-1] - ) - - # reverse right side trias - right_side_trias = right_side_trias[:, ::-1] - - left_side_trias = left_side_trias + current_index - current_index += len(left_side_points) - - right_side_trias = right_side_trias + current_index - current_index += len(right_side_points) - - vertices = [vertices, left_side_points, right_side_points] - faces = [faces, left_side_trias, right_side_trias] - vertex_values = 
[vertex_values, left_side_colors, right_side_colors] - - return CCMesh(vertices, faces, vertex_values=vertex_values, resolution=resolution) +Self = TypeVar('Self', bound='type[CCMesh]') class CCMesh(lapy.TriaMesh): @@ -621,7 +505,8 @@ def snap_cc_picture( self, output_path: Path | str, fssurf_file: Path | str | None = None, - overlay_file: Path | str | None = None + overlay_file: Path | str | None = None, + ref_image: Path | str | nibabelImage | None = None, ) -> None: """Snap a picture of the corpus callosum mesh. @@ -635,6 +520,8 @@ def snap_cc_picture( overlay_file : Path, str, optional Path to a FreeSurfer overlay file to use for the snapshot. If None, the mesh is saved to a temporary file. + ref_image : Path, str, optional + Path to reference image to use for tkr creation. If None, ignores the file for saving. Raises ------ @@ -670,36 +557,38 @@ def snap_cc_picture( fssurf_file = Path(fssurf_file) else: fssurf_file = tempfile.NamedTemporaryFile(suffix=".fssurf", delete=True).name - self.write_fssurf(fssurf_file) + self.write_fssurf(fssurf_file, image=str(ref_image) if isinstance(ref_image, Path) else ref_image) if overlay_file: - overlay_file: str | None = Path(overlay_file) + overlay_file = Path(overlay_file) else: overlay_file = tempfile.NamedTemporaryFile(suffix=".w", delete=True).name # Write thickness values in FreeSurfer '*.w' overlay format self.write_morph_data(overlay_file) - - with suppress_stdout(): - snap1( - fssurf_file, - overlaypath=overlay_file, - view=None, - viewmat=self.__create_cc_viewmat(), - width=3 * 500, - height=3 * 300, - outpath=output_path, - ambient=0.6, - colorbar_scale=0.5, - colorbar_y=0.88, - colorbar_x=0.19, - brain_scale=2.1, - fthresh=0, - caption="Corpus Callosum thickness (mm)", - caption_y=0.85, - caption_x=0.17, - caption_scale=0.5, - ) + try: + with suppress_stdout(): + snap1( + fssurf_file, + overlaypath=overlay_file, + view=None, + viewmat=self.__create_cc_viewmat(), + width=3 * 500, + height=3 * 300, + outpath=output_path, + ambient=0.6, + colorbar_scale=0.5, + colorbar_y=0.88, + colorbar_x=0.19, + brain_scale=2.1, + fthresh=0, + caption="Corpus Callosum thickness (mm)", + caption_y=0.85, + caption_x=0.17, + caption_scale=0.5, + ) + except Exception as e: + raise e from None if fssurf_file and hasattr(fssurf_file, "close"): fssurf_file.close() @@ -744,32 +633,38 @@ def __make_parent_folder(filename: Path | str) -> None: def to_fs_coordinates( self, - vox2ras_tkr: np.ndarray, - ) -> None: + lr_offset: float, + ) -> "CCMesh": """Convert mesh coordinates to FreeSurfer coordinate system. Parameters ---------- - vox2ras_tkr : np.ndarray - 4x4 voxel to RAS tkr-space transformation matrix. + lr_offset : float + Voxel offset to apply before transformation, this should be often `FSAVERAGE_MIDDLE / vox_size_in_lr`. + + Returns + ------- + CCMesh + A CCMesh object with vertices reoriented to FreeSurfer coordinates. Notes ----- - Mesh coordinates seem to be in ASR (Anterior-Superior-Right) orientation, with the coordinate system origin on - *the* midslice. - The function performs the following: - 1. Convert from mesh coordinates (LSA and voxel coordinates) to fsaverage voxel coordinates (LIA, origin). - a. Convert coordinates from ASR to LSA orientation. - b. Convert to voxel coordinates using voxel size. - c. Center LR coordinates and flips SI coordinates. - 2. Apply vox2ras_tkr transformation to get final coordinates. + Mesh coordinates are in ASR (Anterior-Superior-Right) orientation, with the coordinate system origin on + *the* midslice. 
The function transforms from midslice ASR to LIA vox coordinates. """ + from copy import copy + new_object = copy(self) + + asrvox_midslice2orig_vox2vox = np.eye(4) + # to LSA + asrvox_midslice2orig_vox2vox[:, [0, 2]] = asrvox_midslice2orig_vox2vox[:, [2, 0]] + # center LR + asrvox_midslice2orig_vox2vox[0, 3] = lr_offset + # flip SI + asrvox_midslice2orig_vox2vox[:, 1] *= -1 - # to voxel coordinates - v_vox = self.v.copy() - # to LSA - v_vox = v_vox[:, [2, 1, 0]] + # new_object.v = new_object.v[:, [2, 1, 0]] # to voxel # FIXME: why are the vertex positions multiplied by voxel size here? # removed => for center LR, now dividing by resolution => convert fsaverage middle from mm to vox @@ -777,37 +672,41 @@ def to_fs_coordinates( # all other operations are independent of order of operations (distributive) # v_vox /= vox_size[0] # center LR - v_vox[:, 0] += FSAVERAGE_MIDDLE / self.resolution + # new_object.v[:, 0] += FSAVERAGE_MIDDLE / self.resolution # flip SI - v_vox[:, 1] = -v_vox[:, 1] + # new_object.v[:, 1] = -new_object.v[:, 1] #v_vox_test = np.round(v_vox).astype(int) - ## write volume for debugging - # contour_img = np.zeros(orig.shape) - # for i in range(v_vox_test.shape[0]): - # contour_img[v_vox_test[i, 0], v_vox_test[i, 1], v_vox_test[i, 2]] = 1 # tkrRAS = Torig*[C R S 1]' - # Torig: mri_info --vox2ras-tkr orig.mgz + # Torig: mri_info --vox2ras-tkr orig.mgz # https://surfer.nmr.mgh.harvard.edu/fswiki/CoordinateSystems - self.v = (vox2ras_tkr @ np.concatenate([v_vox, np.ones((self.v.shape[0], 1))], axis=1).T).T[:, :3] - # FIXME: why are the vertex positions multiplied by voxel size here? - # self.v = self.v * vox_size[0] - def write_fssurf(self, filename: Path | str) -> None: - """Write the mesh to a FreeSurfer surface file. + v_vox = np.concatenate([self.v, np.ones((self.v.shape[0], 1))], axis=1) + new_object.v = (v_vox @ asrvox_midslice2orig_vox2vox.T)[:, :3] + # new_object.v = (vox2ras_tkr @ np.concatenate([self.v, np.ones((self.v.shape[0], 1))], axis=1).T).T[:, :3] + return new_object + + def write_fssurf(self, filename: Path | str, image: str | object | None = None) -> None: + """Save as Freesurfer Surface Geometry file (wrap Nibabel). Parameters ---------- - filename : Path, str - Path where to save the FreeSurfer surface file. + filename : str + Filename to save to. + image : str, object, None + Path to image or nibabel image object. If specified, the vertices + are assumed to be in voxel coordinates and are converted + to surface RAS (tkr) coordinates before saving. + The expected order of coordinates is (x, y, z) matching + the image voxel indices. Notes ----- - Creates parent directory if needed before writing the file. + Also creates parent directory if needed before writing the file. """ self.__make_parent_folder(filename) - return super().write_fssurf(filename) + return super().write_fssurf(filename, image=image) def write_morph_data(self, filename: Path | str) -> None: """Write the thickness values as a FreeSurfer overlay file. @@ -823,3 +722,123 @@ def write_morph_data(self, filename: Path | str) -> None: """ self.__make_parent_folder(filename) return nib.freesurfer.write_morph_data(filename, self.mesh_vertex_colors) + + @classmethod + def from_contours( + cls: Self, + contours: list[CCContour], + lr_center: float = 0, + closed: bool = False, + smooth: int = 0, + ) -> Self: + """Create a surface mesh by triangulating between consecutive contours. + + Parameters + ---------- + contours : list[CCContour] + List of CCContour objects to create mesh from. 
+ lr_center : float, default=0 + Center position in the left-right axis. + closed : bool, default=False + Whether to create a closed mesh by adding caps. + smooth : int, default=0 + Number of smoothing iterations to apply. + + Returns + ------- + CCMesh + The joined CCMesh object. + + Raises + ------ + Warning + If no valid contours are found. + + Notes + ----- + The function: + 1. Filters out None contours. + 2. Calculates z-coordinates for each slice. + 3. Creates triangles between adjacent contours. + 4. Optionally: + - Creates caps at both ends. + - Applies smoothing. + - Colors caps based on thickness values. + + """ + + # Check that all contours have the same resolution + resolution = contours[0].resolution + for idx, contour in enumerate(contours[1:], start=1): + if not np.isclose(contour.resolution, resolution): + raise ValueError( + f"All contours must have the same resolution. " + f"Expected {resolution}, but contour at index {idx} has {contour.resolution}." + ) + + # Calculate z coordinates for each slice + z_coordinates = (np.arange(len(contours)) - len(contours) // 2) * contours[0].resolution + lr_center + + # Build vertices list with z-coordinates + vertices = [] + faces = [] + vertex_start_indices = [] # Track starting index for each contour + current_index = 0 + + for i, contour in enumerate(contours): + vertex_start_indices.append(current_index) + vertices.append(np.hstack([contour.contour, np.full((len(contour.contour), 1), z_coordinates[i])])) + + # Check if there's a next valid contour to connect to + if i + 1 < len(contours): + contour2 = contours[i + 1] + faces_between = make_triangles_between_contours(contour.contour, contour2.contour) + faces.append(faces_between + current_index) + + current_index += len(contour.contour) + + vertex_values = np.concatenate([contour.thickness_values for contour in contours]) + + if smooth > 0: + tmp_mesh = CCMesh(vertices, faces, vertex_values=vertex_values) + tmp_mesh.smooth_(smooth) + vertices = tmp_mesh.v + faces = tmp_mesh.t + vertex_values = tmp_mesh.mesh_vertex_colors + + if closed: + # Close the mesh by creating caps on both ends + # Left cap (first slice) - use counterclockwise orientation + left_side_points, left_side_trias = make_mesh_from_contour(vertices[: vertex_start_indices[1]][..., :2]) + left_side_points = np.hstack([left_side_points, np.full((len(left_side_points), 1), z_coordinates[0])]) + + # Right cap (last slice) - reverse points for proper orientation + right_side_points, right_side_trias = make_mesh_from_contour(vertices[vertex_start_indices[-1]:][..., :2]) + right_side_points = np.hstack([right_side_points, np.full((len(right_side_points), 1), z_coordinates[-1])]) + + #FIXME: Can we remove this if-statement? + color_sides = True + if color_sides: + left_side_points, left_side_trias, left_side_colors = _create_cap( + left_side_points, left_side_trias, contours[0] + ) + right_side_points, right_side_trias, right_side_colors = _create_cap( + right_side_points, right_side_trias, contours[-1] + ) + # reverse right side trias + right_side_trias = right_side_trias[:, ::-1] + else: + left_side_colors, right_side_colors = [], [] + + left_side_trias = left_side_trias + current_index + current_index += len(left_side_points) + + right_side_trias = right_side_trias + current_index + current_index += len(right_side_points) + + # FIXME: should this not be a concatenate statements? 
+ vertices = [vertices, left_side_points, right_side_points] + faces = [faces, left_side_trias, right_side_trias] + vertex_values = [vertex_values, left_side_colors, right_side_colors] + + return cls(vertices, faces, vertex_values=vertex_values, resolution=resolution) diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index 348f515e..da209195 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -23,7 +23,7 @@ from CorpusCallosum.data.constants import CC_LABEL, FSAVERAGE_MIDDLE, SUBSEGMENT_LABELS from CorpusCallosum.shape.contour import CCContour from CorpusCallosum.shape.endpoint_heuristic import get_endpoints -from CorpusCallosum.shape.mesh import create_CC_mesh_from_contours +from CorpusCallosum.shape.mesh import CCMesh from CorpusCallosum.shape.metrics import calculate_cc_index from CorpusCallosum.shape.subsegment_contour import ( ContourList, @@ -36,8 +36,8 @@ from CorpusCallosum.shape.thickness import cc_thickness, convert_to_ras from CorpusCallosum.utils.types import CCMeasuresDict, ContourThickness, Points2dType, SliceSelection, SubdivisionMethod from CorpusCallosum.utils.visualization import plot_contours -from FastSurferCNN.utils import AffineMatrix4x4, Image3d, ScalarType, Shape2d, Shape3d, Vector2d -from FastSurferCNN.utils.common import SubjectDirectory, suppress_stdout, update_docstring +from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Mask2d, ScalarType, Shape2d, Shape3d, Vector2d +from FastSurferCNN.utils.common import SubjectDirectory, update_docstring from FastSurferCNN.utils.parallel import process_executor, thread_executor logger = logging.get_logger(__name__) @@ -73,7 +73,7 @@ def create_sag_slice_vox2vox(slice_idx: int, fsaverage_middle: float) -> AffineM @update_docstring(SubdivisionMethod=str(get_args(SubdivisionMethod))[1:-1]) def recon_cc_surf_measures_multi( - segmentation: np.ndarray[Shape3d, np.dtype[int]], + segmentation: np.ndarray[Shape3d, np.dtype[np.int_]], slice_selection: SliceSelection, fsavg_vox2ras: AffineMatrix4x4, midslices: Image3d, @@ -85,7 +85,6 @@ def recon_cc_surf_measures_multi( contour_smoothing: int, subject_dir: SubjectDirectory, vox_size: tuple[float, float, float], - vox2ras_tkr: AffineMatrix4x4 | None = None, ) -> tuple[list[CCMeasuresDict], list[concurrent.futures.Future]]: """Surface reconstruction and metrics computation of corpus callosum slices based on selection mode. @@ -115,8 +114,6 @@ def recon_cc_surf_measures_multi( The SubjectDirectory object managing file names in the subject directory. vox_size : 3-tuple of floats LIA-oriented voxel size in millimeters (x, y, z). - vox2ras_tkr : np.ndarray, optional - Voxel to RAS tkr-space transformation matrix. 
Returns ------- @@ -166,6 +163,7 @@ def recon_cc_surf_measures_multi( per_slice_recon = process_executor().map(_each_slice, slices_to_recon, per_slice_vox2ras, chunksize=1) cc_contours = [] + run = thread_executor().submit for i, (slice_idx, _results) in enumerate(zip(slices_to_recon, per_slice_recon, strict=True)): progress = f" ({i+1} of {num_slices})" if num_slices > 1 else "" logger.info(f"Calculating CC measurements for slice {slice_idx+1}{progress}") @@ -174,7 +172,7 @@ def recon_cc_surf_measures_multi( contour_in_as_space_and_thickness: ContourThickness = _results[1] endpoint_idxs: tuple[int, int] = _results[2] contour_in_as_space: Points2dType = contour_in_as_space_and_thickness[:, :2] - thickness_values: np.ndarray[tuple[int], np.dtype[float]] = contour_in_as_space_and_thickness[:, 2] + thickness_values: np.ndarray[tuple[int], np.dtype[np.float_]] = contour_in_as_space_and_thickness[:, 2] cc_contours.append(CCContour(contour_in_as_space, thickness_values, endpoint_idxs, resolution=vox_size[0])) if cc_measures is None: @@ -185,7 +183,7 @@ def recon_cc_surf_measures_multi( is_debug = logger.getEffectiveLevel() <= logging.DEBUG is_midslice = slice_idx == num_slices // 2 if subject_dir.has_attribute("cc_qc_image") and (is_debug or is_midslice): - qc_imgs: list[Path] = (subject_dir.filename_by_attribute("cc_qc_image"),) + qc_imgs: list[Path] = [subject_dir.filename_by_attribute("cc_qc_image")] if is_debug: qc_slice_img = qc_imgs[0].with_suffix(f".slice_{slice_idx}.png") qc_imgs = (qc_imgs if is_midslice else []) + [qc_slice_img] @@ -194,7 +192,7 @@ def recon_cc_surf_measures_multi( current_slice_in_volume = midslices.shape[0] // 2 - num_slices // 2 + slice_idx # Create visualization for this slice io_futures.append( - thread_executor().submit( + run( plot_contours, transformed=midslices[current_slice_in_volume:current_slice_in_volume+1], split_contours=cc_measures["split_contours"], @@ -215,7 +213,7 @@ def recon_cc_surf_measures_multi( template_dir.mkdir(parents=True, exist_ok=True) logger.info("Saving template files (contours.txt, thickness_values.txt, " f"thickness_measurement_points.txt) to {template_dir}") - run = thread_executor().submit + run = run for j in range(len(cc_contours)): # FIXME: check, if this is fixed (thickness values not nan == 200) # this does not seem to be thread-safe, do not parallelize! 
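In the hunks above, file output is funneled through `run = thread_executor().submit`, so every save becomes a future collected in `io_futures` and resolved before the pipeline exits. A self-contained sketch of this submit-and-collect pattern, with a standard-library executor standing in for FastSurfer's `thread_executor()` and a hypothetical `save_txt` standing in for the submitted `save_contour`/`save_thickness_values` calls:

```python
# Deferred-IO sketch; ThreadPoolExecutor stands in for
# FastSurferCNN.utils.parallel.thread_executor().
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path


def save_txt(path: Path) -> None:
    # hypothetical stand-in for the save_* methods submitted above
    path.write_text("placeholder")


with ThreadPoolExecutor(max_workers=4) as executor:
    run = executor.submit  # same aliasing as in the hunk above
    io_futures: list[Future] = [
        run(save_txt, Path(f"thickness_values_{j}.txt")) for j in range(3)
    ]
    for fut in io_futures:
        fut.result()  # re-raises any exception from a worker thread
```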
@@ -225,34 +223,34 @@ def recon_cc_surf_measures_multi( mesh_outputs = ("html", "mesh", "thickness_overlay", "surf", "thickness_image") if len(cc_contours) > 1 and any(subject_dir.has_attribute(f"cc_{n}") for n in mesh_outputs): _cc_contours = thread_executor().map(_resample_thickness, cc_contours) - cc_mesh = create_CC_mesh_from_contours(list(cc_contours), smooth=1) + cc_mesh = CCMesh.from_contours(list(_cc_contours), smooth=1) if subject_dir.has_attribute("cc_html"): logger.info(f"Saving CC 3D visualization to {subject_dir.filename_by_attribute('cc_html')}") - io_futures.append(thread_executor().submit( - cc_mesh.plot_mesh,output_path=subject_dir.filename_by_attribute("cc_html"))) + io_futures.append(run( + cc_mesh.plot_mesh,output_path=subject_dir.filename_by_attribute("cc_html")), + ) if subject_dir.has_attribute("cc_mesh"): vtk_file_path = subject_dir.filename_by_attribute("cc_mesh") logger.info(f"Saving vtk file to {vtk_file_path}") - io_futures.append(thread_executor().submit(cc_mesh.write_vtk, vtk_file_path)) + io_futures.append(run(cc_mesh.write_vtk, vtk_file_path)) - cc_mesh.to_fs_coordinates(vox2ras_tkr=vox2ras_tkr) + # the mesh is generated in upright coordinates, so we need to also transform to orig coordinates + cc_mesh = cc_mesh.to_fs_coordinates(lr_offset=FSAVERAGE_MIDDLE / vox_size[0]) if subject_dir.has_attribute("cc_thickness_overlay"): overlay_file_path = subject_dir.filename_by_attribute("cc_thickness_overlay") logger.info(f"Saving overlay file to {overlay_file_path}") - io_futures.append(thread_executor().submit(cc_mesh.write_morph_data, overlay_file_path)) + io_futures.append(run(cc_mesh.write_morph_data, overlay_file_path)) if subject_dir.has_attribute("cc_surf"): surf_file_path = subject_dir.filename_by_attribute("cc_surf") logger.info(f"Saving surf file to {surf_file_path}") - io_futures.append(thread_executor().submit(cc_mesh.write_fssurf, surf_file_path)) + io_futures.append(run(cc_mesh.write_fssurf, str(surf_file_path), str(subject_dir.conf_name))) if subject_dir.has_attribute("cc_thickness_image"): thickness_image_path = subject_dir.filename_by_attribute("cc_thickness_image") logger.info(f"Saving thickness image to {thickness_image_path}") - # note: suppress_stdout is not thread-safe! But it works fine, if only one thread uses it... - with suppress_stdout(): - cc_mesh.snap_cc_picture(thickness_image_path) + cc_mesh.snap_cc_picture(thickness_image_path, subject_dir.conf_name) if not slice_cc_measures: logger.error("Error: No valid slices were found for postprocessing") @@ -269,7 +267,7 @@ def _resample_thickness(contour: CCContour) -> CCContour: def recon_cc_surf_measure( - segmentation: np.ndarray[Shape2d, np.dtype[int]], + segmentation: np.ndarray[Shape2d, np.dtype[np.int_]], slice_idx: int, affine: AffineMatrix4x4, ac_coords: Vector2d, @@ -328,7 +326,7 @@ def recon_cc_surf_measure( 4. Computes shape metrics and subdivisions. 5. Generates visualization data. 
""" - cc_mask_slice: np.ndarray[tuple[int, int], np.dtype[bool]] = np.equal(segmentation[slice_idx], CC_LABEL) + cc_mask_slice: Mask2d = np.equal(segmentation[slice_idx], CC_LABEL) if not np.any(cc_mask_slice): raise ValueError(f"No CC found in slice {slice_idx}") contour, endpoint_idxs = get_endpoints( @@ -415,7 +413,7 @@ def vectorized_line_test( coords_y: np.ndarray[tuple[int], np.dtype[ScalarType]], line_start: Vector2d, line_end: Vector2d, -) -> np.ndarray[tuple[int], np.dtype[bool]]: +) -> np.ndarray[tuple[int], np.dtype[np.bool_]]: """Vectorized version of point_relative_to_line for arrays of points. Parameters @@ -507,8 +505,8 @@ def get_unique_contour_points(split_contours: ContourList) -> list[Points2dType] def make_subdivision_mask( slice_shape: Shape2d, split_contours: ContourList, - vox_size: tuple[float, float, float], -) -> np.ndarray[Shape2d, np.dtype[int]]: + vox_size: tuple[float, float], +) -> np.ndarray[Shape2d, np.dtype[np.int_]]: """Create a mask for subdividing the corpus callosum based on split contours. Parameters diff --git a/CorpusCallosum/shape/subsegment_contour.py b/CorpusCallosum/shape/subsegment_contour.py index a98037c8..b730cb51 100644 --- a/CorpusCallosum/shape/subsegment_contour.py +++ b/CorpusCallosum/shape/subsegment_contour.py @@ -107,7 +107,7 @@ def calc_subsegment_areas(split_contours: ContourList) -> np.ndarray[tuple[int], def subsegment_midline_orthogonal( midline: Points2dType, - area_weights: np.ndarray[tuple[int], np.dtype[float]], + area_weights: np.ndarray[tuple[int], np.dtype[np.float_]], contour: Polygon2dType, plot: bool = True, ax=None, @@ -368,7 +368,7 @@ def subsegment_midline_orthogonal( def hampel_subdivide_contour(contour: Polygon2dType, num_rays: int, plot: bool = False, ax=None) \ - -> tuple[np.ndarray[tuple[int], np.dtype[float]], ContourList]: + -> tuple[np.ndarray[tuple[int], np.dtype[np.float_]], ContourList]: # FIXME: needs docstring # Find the extreme points in the x-direction min_x_index = np.argmin(contour[0]) @@ -513,7 +513,7 @@ def subdivide_contour( plot_transform: Callable | None = None, oriented: bool = False, hline_anchor: np.ndarray | None = None -) -> tuple[np.ndarray[tuple[int], np.dtype[float]], ContourList]: +) -> tuple[np.ndarray[tuple[int], np.dtype[np.float_]], ContourList]: """Subdivide contour based on area weights using vertical lines. Divides the contour into segments by drawing vertical lines at positions diff --git a/CorpusCallosum/shape/thickness.py b/CorpusCallosum/shape/thickness.py index 10fd9c8f..0303d1fe 100644 --- a/CorpusCallosum/shape/thickness.py +++ b/CorpusCallosum/shape/thickness.py @@ -23,7 +23,7 @@ from FastSurferCNN.utils.common import suppress_stdout -def compute_curvature(path: Points2dType) -> np.ndarray[tuple[int], np.dtype[float]]: +def compute_curvature(path: Points2dType) -> np.ndarray[tuple[int], np.dtype[np.float_]]: """Compute curvature by computing edge angles. 
Parameters diff --git a/CorpusCallosum/transforms/segmentation.py b/CorpusCallosum/transforms/segmentation.py index 9d2c7268..2b54b450 100644 --- a/CorpusCallosum/transforms/segmentation.py +++ b/CorpusCallosum/transforms/segmentation.py @@ -88,7 +88,7 @@ def __call__(self, data: dict) -> dict: ac_pc_bottomleft = np.min(ac_pc, axis=0).astype(int) ac_pc_topright = np.max(ac_pc, axis=0).astype(int) - VoxPadType = np.ndarray[tuple[Literal[2]], np.dtype[int]] + VoxPadType = np.ndarray[tuple[Literal[2]], np.dtype[np.int_]] voxel_padding: VoxPadType = np.round(self.padding_mm / d["res"]).astype(int) crop_left = ac_pc_bottomleft[1] - int(voxel_padding[0] * 1.5) + random_translate[0] diff --git a/CorpusCallosum/utils/mapping_helpers.py b/CorpusCallosum/utils/mapping_helpers.py index 8e103dee..8dd72e5d 100644 --- a/CorpusCallosum/utils/mapping_helpers.py +++ b/CorpusCallosum/utils/mapping_helpers.py @@ -13,6 +13,7 @@ AffineMatrix4x4, Image2d, Image3d, + Image4d, RotationMatrix3x3, Shape3d, Vector2d, @@ -296,30 +297,52 @@ def make_affine(simpleITKImage: sitk.Image) -> AffineMatrix4x4: return affine +@overload def map_softlabels_to_orig( - cc_fn_softlabels: Image3d, - orig_fsaverage_vox2vox: AffineMatrix4x4, + cc_fn_softlabels: Image4d, + orig: nibabelImage, + orig2slab_vox2vox: AffineMatrix4x4, + cc_subseg_midslice: None = None, + orig2midslice_vox2vox: None = None, + orig_space_segmentation_path: str | Path | None = None, +) -> np.ndarray[Shape3d, np.dtype[np.int_]]: ... + + +@overload +def map_softlabels_to_orig( + cc_fn_softlabels: Image4d, orig: nibabelImage, + orig2slab_vox2vox: AffineMatrix4x4, + cc_subseg_midslice: Image2d, + orig2midslice_vox2vox: AffineMatrix4x4, orig_space_segmentation_path: str | Path | None = None, - fsaverage_middle: int = 128, - cc_subseg_midslice: Image2d | None = None -) -> np.ndarray[Shape3d, np.dtype[int]]: +) -> np.ndarray[Shape3d, np.dtype[np.int_]]: ... + + +def map_softlabels_to_orig( + cc_fn_softlabels: Image4d, + orig: nibabelImage, + orig2slab_vox2vox: AffineMatrix4x4, + cc_subseg_midslice: Image2d | None = None, + orig2midslice_vox2vox: AffineMatrix4x4 | None = None, + orig_space_segmentation_path: str | Path | None = None, +) -> np.ndarray[Shape3d, np.dtype[np.int_]]: """Map soft labels back to original image space and apply post-processing. Parameters ---------- cc_fn_softlabels : np.ndarray - Soft label predictions. - orig_fsaverage_vox2vox : AffineMatrix4x4 - Original to fsaverage space transformation. + Soft label predictions of shape (H, W, D, C=3). orig : nibabelImage Original image. + orig2slab_vox2vox : AffineMatrix4x4 + The vox2vox transformation matrix from orig to the slab. + cc_subseg_midslice : np.ndarray, optional + Mask for subdividing regions of shape (H, D) (only paired with orig2midslice_vox2vox). + orig2midslice_vox2vox : AffineMatrix4x4, optional + The vox2vox transformation matrix from orig to the midslice (only paired with cc_subseg_midslice). orig_space_segmentation_path : str or Path, optional Path to save segmentation in original space. - fsaverage_middle : int, default=128 - Middle slice index in fsaverage space. - cc_subseg_midslice : np.ndarray, optional - Mask for subdividing regions. Returns ------- @@ -333,37 +356,34 @@ def map_softlabels_to_orig( 2. Transform CC subsegmentation from midslice to orig and paint into segmentation if `cc_subseg_midslice` is passed. 4. Saves result to `orig_space_segmentation_path` if passed. 
""" - slices_to_analyze = cc_fn_softlabels.shape[0] # map softlabels to original image - slab2fsaverage_vox2vox = np.eye(4) - slab2fsaverage_vox2vox[0, 3] = -(fsaverage_middle - slices_to_analyze // 2) - slab2orig_vox2vox = orig_fsaverage_vox2vox @ slab2fsaverage_vox2vox - - def _map_softlabel_to_orig(i: int, data: Image3d) -> Image3d: - return affine_transform(data, slab2orig_vox2vox, output_shape=orig.shape, order=1, cval=float(i == 0)) - - _softlabels = np.moveaxis(cc_fn_softlabels, -1, 0) - softlabels_transformed = thread_executor().map(_map_softlabel_to_orig, *zip(*enumerate(_softlabels), strict=True)) - - softlabels_orig_space = np.stack(list(softlabels_transformed), axis=-1) - seg_orig_space = np.argmax(softlabels_orig_space, axis=-1) - # map to freesurfer labels - seg_lut = np.asarray([0, CC_LABEL, FORNIX_LABEL]) - seg_orig_space = seg_lut[seg_orig_space] - - if cc_subseg_midslice is not None: - # map subdivision mask to orig space - midslice2fsaverage_vox2vox = np.eye(4) - midslice2fsaverage_vox2vox[0, 3] = -fsaverage_middle - cc_subseg_orig_space = affine_transform( + def _map_softlabel_to_orig(data: Image3d, fill: int) -> Image3d: + # # Note: affine_transforms requires the inverse of the intended direction -> orig2slab + return affine_transform(data, orig2slab_vox2vox, output_shape=orig.shape, order=1, cval=fill) + + if cc_subseg_midslice is not None and orig2midslice_vox2vox is not None: + # map subdivision mask to orig space, this will also expand the labels into left-right direction + cc_subseg_orig_space_fut = thread_executor().submit( + affine_transform, cc_subseg_midslice[None], - orig_fsaverage_vox2vox, + orig2midslice_vox2vox, # Note: affine_transforms requires the inverse of the intended direction output_shape=orig.shape, order=0, mode="nearest", ) + else: + cc_subseg_orig_space_fut = None - seg_orig_space = np.where(seg_orig_space == CC_LABEL, cc_subseg_orig_space, seg_orig_space) + _softlabels = np.moveaxis(cc_fn_softlabels, -1, 0) + softlabels_iter = thread_executor().map(_map_softlabel_to_orig, _softlabels, [1., 0., 0.]) + softlabels_orig_space = np.stack(list(softlabels_iter), axis=-1) + # map to freesurfer labels + seg_lut = np.asarray([0, CC_LABEL, FORNIX_LABEL]) + seg_orig_space = seg_lut[np.argmax(softlabels_orig_space, axis=-1)] + + if cc_subseg_orig_space_fut is not None: + # replace CC_LABEL by subsegmentation labels + seg_orig_space = np.where(seg_orig_space == CC_LABEL, cc_subseg_orig_space_fut.result(), seg_orig_space) if orig_space_segmentation_path is not None: logger.info(f"Saving segmentation in original space to {orig_space_segmentation_path}") @@ -371,5 +391,4 @@ def _map_softlabel_to_orig(i: int, data: Image3d) -> Image3d: nib.MGHImage(seg_orig_space, orig.affine, orig.header), orig_space_segmentation_path, ) - return seg_orig_space diff --git a/FastSurferCNN/utils/__init__.py b/FastSurferCNN/utils/__init__.py index 24b5163d..005a63e4 100644 --- a/FastSurferCNN/utils/__init__.py +++ b/FastSurferCNN/utils/__init__.py @@ -22,7 +22,6 @@ "load_config", "logging", "lr_scheduler", - "LTADict", "mapper", "Mask2d", "Mask3d", diff --git a/env/fastsurfer.yml b/env/fastsurfer.yml index 7c73f303..a3c5231d 100644 --- a/env/fastsurfer.yml +++ b/env/fastsurfer.yml @@ -5,7 +5,7 @@ channels: dependencies: - h5py==3.12.1 -- lapy==1.2.0 +- lapy==1.4.0 - matplotlib==3.10.1 - monai==1.4.0 - nibabel==5.3.2 diff --git a/pyproject.toml b/pyproject.toml index 6c116e0c..46dd1c40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ classifiers = [ ] 
dependencies = [ 'h5py>=3.7', - 'lapy>=1.1.0', + "lapy>=1.4.0", 'matplotlib>=3.7.1', 'nibabel>=5.1.0', 'numpy>=1.25,<2', From bbb66c7a1036b6b5a3852b59149ae70d69b62e6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Tue, 16 Dec 2025 19:00:46 +0100 Subject: [PATCH 51/68] Fix surface output --- CorpusCallosum/fastsurfer_cc.py | 1 + CorpusCallosum/shape/mesh.py | 2 +- CorpusCallosum/shape/postprocessing.py | 36 ++++++++++++++++---------- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index a927bf48..dfbeaa0e 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -805,6 +805,7 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: logger.info(f"Processing slices with selection mode: {slice_selection}") slice_results, slice_io_futures = recon_cc_surf_measures_multi( segmentation=cc_fn_seg_labels, + upright_affine_header=(fsavg_vox2ras, orig.header), slice_selection=slice_selection, fsavg_vox2ras=fsavg_vox2ras, midslices=midslices, diff --git a/CorpusCallosum/shape/mesh.py b/CorpusCallosum/shape/mesh.py index 74af3c63..ca4325f0 100644 --- a/CorpusCallosum/shape/mesh.py +++ b/CorpusCallosum/shape/mesh.py @@ -520,7 +520,7 @@ def snap_cc_picture( overlay_file : Path, str, optional Path to a FreeSurfer overlay file to use for the snapshot. If None, the mesh is saved to a temporary file. - ref_image : Path, str, optional + ref_image : Path, str, nibabelImage, optional Path to reference image to use for tkr creation. If None, ignores the file for saving. Raises diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index da209195..440c6787 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -36,7 +36,8 @@ from CorpusCallosum.shape.thickness import cc_thickness, convert_to_ras from CorpusCallosum.utils.types import CCMeasuresDict, ContourThickness, Points2dType, SliceSelection, SubdivisionMethod from CorpusCallosum.utils.visualization import plot_contours -from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Mask2d, ScalarType, Shape2d, Shape3d, Vector2d +from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Mask2d, ScalarType, Shape2d, Shape3d, Vector2d, nibabelImage, \ + nibabelHeader from FastSurferCNN.utils.common import SubjectDirectory, update_docstring from FastSurferCNN.utils.parallel import process_executor, thread_executor @@ -74,6 +75,7 @@ def create_sag_slice_vox2vox(slice_idx: int, fsaverage_middle: float) -> AffineM @update_docstring(SubdivisionMethod=str(get_args(SubdivisionMethod))[1:-1]) def recon_cc_surf_measures_multi( segmentation: np.ndarray[Shape3d, np.dtype[np.int_]], + upright_affine_header: tuple[AffineMatrix4x4, nibabelHeader], slice_selection: SliceSelection, fsavg_vox2ras: AffineMatrix4x4, midslices: Image3d, @@ -92,6 +94,8 @@ def recon_cc_surf_measures_multi( ---------- segmentation : np.ndarray 3D segmentation array. + upright_affine_header : tuple[AffineMatrix4x4, nibabelHeader] + A tuple of the vox2ras matrix and the header of the upright image. slice_selection : str Which slices to process ('middle', 'all', or slice number). 
fsavg_vox2ras : np.ndarray @@ -164,6 +168,7 @@ def recon_cc_surf_measures_multi( cc_contours = [] run = thread_executor().submit + wants_output = subject_dir.has_attribute for i, (slice_idx, _results) in enumerate(zip(slices_to_recon, per_slice_recon, strict=True)): progress = f" ({i+1} of {num_slices})" if num_slices > 1 else "" logger.info(f"Calculating CC measurements for slice {slice_idx+1}{progress}") @@ -182,7 +187,7 @@ def recon_cc_surf_measures_multi( slice_cc_measures.append(cc_measures) is_debug = logger.getEffectiveLevel() <= logging.DEBUG is_midslice = slice_idx == num_slices // 2 - if subject_dir.has_attribute("cc_qc_image") and (is_debug or is_midslice): + if wants_output("cc_qc_image") and (is_debug or is_midslice): qc_imgs: list[Path] = [subject_dir.filename_by_attribute("cc_qc_image")] if is_debug: qc_slice_img = qc_imgs[0].with_suffix(f".slice_{slice_idx}.png") @@ -207,7 +212,7 @@ def recon_cc_surf_measures_multi( ) - if subject_dir.has_attribute("save_template_dir"): + if wants_output("save_template_dir"): template_dir = subject_dir.filename_by_attribute("save_template_dir") # ensure directory exists template_dir.mkdir(parents=True, exist_ok=True) @@ -221,36 +226,41 @@ io_futures.append(run(cc_contours[j].save_thickness_values, template_dir / f"thickness_values_{j}.txt")) mesh_outputs = ("html", "mesh", "thickness_overlay", "surf", "thickness_image") - if len(cc_contours) > 1 and any(subject_dir.has_attribute(f"cc_{n}") for n in mesh_outputs): + if len(cc_contours) > 1 and any(wants_output(f"cc_{n}") for n in mesh_outputs): _cc_contours = thread_executor().map(_resample_thickness, cc_contours) cc_mesh = CCMesh.from_contours(list(_cc_contours), smooth=1) - if subject_dir.has_attribute("cc_html"): + if wants_output("cc_html"): logger.info(f"Saving CC 3D visualization to {subject_dir.filename_by_attribute('cc_html')}") io_futures.append(run(cc_mesh.plot_mesh, output_path=subject_dir.filename_by_attribute("cc_html")), ) - if subject_dir.has_attribute("cc_mesh"): + if wants_output("cc_mesh"): vtk_file_path = subject_dir.filename_by_attribute("cc_mesh") logger.info(f"Saving vtk file to {vtk_file_path}") io_futures.append(run(cc_mesh.write_vtk, vtk_file_path)) - # the mesh is generated in upright coordinates, so we need to also transform to orig coordinates - cc_mesh = cc_mesh.to_fs_coordinates(lr_offset=FSAVERAGE_MIDDLE / vox_size[0]) - if subject_dir.has_attribute("cc_thickness_overlay"): + if wants_output("cc_thickness_overlay"): overlay_file_path = subject_dir.filename_by_attribute("cc_thickness_overlay") logger.info(f"Saving overlay file to {overlay_file_path}") io_futures.append(run(cc_mesh.write_morph_data, overlay_file_path)) - if subject_dir.has_attribute("cc_surf"): + if any(wants_output(f"cc_{n}") for n in ("thickness_image", "surf")): + import nibabel as nib + # geometry-only stub image carrying the upright affine and header for tkr conversion + upright_vox2ras, upright_header = upright_affine_header + upright_img = nib.MGHImage(np.zeros(upright_header.get_data_shape()[:3], dtype=np.uint8), upright_vox2ras, upright_header) + + # the mesh is generated in upright coordinates, so we need to also transform to orig coordinates + cc_mesh = cc_mesh.to_fs_coordinates(lr_offset=FSAVERAGE_MIDDLE / vox_size[0]) + if wants_output("cc_surf"): surf_file_path = subject_dir.filename_by_attribute("cc_surf") logger.info(f"Saving surf file to {surf_file_path}") - io_futures.append(run(cc_mesh.write_fssurf, str(surf_file_path), str(subject_dir.conf_name))) + cc_mesh.write_fssurf(str(surf_file_path), image=upright_img) + # io_futures.append(run(cc_mesh.write_fssurf, str(surf_file_path), image=orig)) - if subject_dir.has_attribute("cc_thickness_image"):
+ if wants_output("cc_thickness_image"): thickness_image_path = subject_dir.filename_by_attribute("cc_thickness_image") logger.info(f"Saving thickness image to {thickness_image_path}") - cc_mesh.snap_cc_picture(thickness_image_path, subject_dir.conf_name) + cc_mesh.snap_cc_picture(thickness_image_path, ref_image=upright_img) if not slice_cc_measures: logger.error("Error: No valid slices were found for postprocessing") From adfd52360de566a19d3c73e6bdcdc71752f26ddb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Fri, 19 Dec 2025 11:35:00 +0100 Subject: [PATCH 52/68] - Clean up docstrings, typing - change instatiation of CCContour to from_* syntax - update io functions for CCContour - Optimize CCContour - Fix vox2ras_tkr application for surface to save - Add and use alternative implementation for CC endpoint finding (requires no image rotation) - Remove resolution, where it is not needed and build all vox2vox and vox2ras operations on transformation matrices (including surfaces and image to midslab/midslice) --- CorpusCallosum/cc_visualization.py | 14 +- CorpusCallosum/data/fsaverage_cc_template.py | 27 +- CorpusCallosum/data/read_write.py | 21 +- CorpusCallosum/fastsurfer_cc.py | 144 ++++++--- .../segmentation_postprocessing.py | 6 +- CorpusCallosum/shape/contour.py | 274 +++++++++++------ CorpusCallosum/shape/endpoint_heuristic.py | 194 ++++++++---- CorpusCallosum/shape/mesh.py | 109 +++---- CorpusCallosum/shape/postprocessing.py | 282 +++++++++++------- CorpusCallosum/shape/subsegment_contour.py | 51 +--- CorpusCallosum/shape/thickness.py | 103 +------ CorpusCallosum/utils/mapping_helpers.py | 40 +-- 12 files changed, 700 insertions(+), 565 deletions(-) diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index c6911729..d05305c2 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -135,16 +135,18 @@ def load_contours_from_template_dir( num_thickness_values = np.sum(~np.isnan(np.array(thickness_values[1:],dtype=float))) if fsaverage_contour is None: fsaverage_contour = load_fsaverage_cc_template() - # create measurement points (points = 2 x levelpaths) accorindg to number of thickness values + # create measurement points (points = 2 x levelpaths) according to number of thickness values fsaverage_contour.create_levelpaths(num_points=num_thickness_values // 2, update_data=True) current_contour = fsaverage_contour.copy() current_contour.load_thickness_values(thickness_file) else: # this is kinda ugly - maybe we need to overload the constructor to load the contour and thickness values? - current_contour = CCContour(np.empty((0, 2)), np.empty((0,)), resolution=resolution) - current_contour.load_contour(contour_file) - current_contour.load_thickness_values(thickness_file) + # FIXME: The z_position in from_contour is still incorrect, currently all Contours would be "registered" for + # the midslice. + current_contour = CCContour.from_contour_file(contour_file, thickness_file, z_position=0.0) + # current_contour.load_contour(contour_file) + # current_contour.load_thickness_values(thickness_file) current_contour.fill_thickness_values() contours.append(current_contour) @@ -194,6 +196,7 @@ def main( return 0 # 3D visualization + # FIXME: This function would need contours[i].z_position to be properly initialized! 
cc_mesh = CCMesh.from_contours(contours, smooth=0) plot_kwargs = dict( @@ -205,7 +208,8 @@ def main( cc_mesh.plot_mesh(**plot_kwargs) cc_mesh.plot_mesh(output_path=str(output_dir / "cc_mesh.html"), **plot_kwargs) - cc_mesh = cc_mesh.to_fs_coordinates(lr_offset=FSAVERAGE_MIDDLE / resolution) + #FIXME: needs to be adapted to new interface of CCMesh.to_fs_coordinates / to_vox_coordinates + cc_mesh = cc_mesh.to_vox_coordinates(lr_offset=FSAVERAGE_MIDDLE / resolution) logger.info(f"Writing vtk file to {output_dir / 'cc_mesh.vtk'}") cc_mesh.write_vtk(str(output_dir / "cc_mesh.vtk")) logger.info(f"Writing freesurfer surface file to {output_dir / 'cc_mesh.fssurf'}") diff --git a/CorpusCallosum/data/fsaverage_cc_template.py b/CorpusCallosum/data/fsaverage_cc_template.py index 55d11bc2..f52a849a 100644 --- a/CorpusCallosum/data/fsaverage_cc_template.py +++ b/CorpusCallosum/data/fsaverage_cc_template.py @@ -14,6 +14,7 @@ import os from pathlib import Path +from typing import cast import nibabel as nib import numpy as np @@ -22,8 +23,12 @@ from CorpusCallosum.data import constants from CorpusCallosum.shape.contour import CCContour from CorpusCallosum.shape.postprocessing import recon_cc_surf_measure +from FastSurferCNN.utils import nibabelImage from FastSurferCNN.utils.brainvolstats import mask_in_array +FSAVERAGE_PC_COORDINATE = np.array([131, 99]) +FSAVERAGE_AC_COORDINATE = np.array([135, 130]) + def smooth_contour(contour: tuple[np.ndarray, np.ndarray], window_size: int = 5) -> tuple[np.ndarray, np.ndarray]: """Smooth a contour using a moving average filter. @@ -62,9 +67,7 @@ def smooth_contour(contour: tuple[np.ndarray, np.ndarray], window_size: int = 5) return (x_smoothed, y_smoothed) -def load_fsaverage_cc_template() -> tuple[ - np.ndarray, tuple[np.ndarray, np.ndarray], np.ndarray, np.ndarray, np.ndarray, tuple[int, int] -]: +def load_fsaverage_cc_template() -> CCContour: """Load and process the fsaverage corpus callosum template. This function loads the fsaverage segmentation from FreeSurfer's data directory, @@ -72,8 +75,8 @@ def load_fsaverage_cc_template() -> tuple[ Returns ------- - tuple - Contains: + CCContour + Object with all the contour information including: - contour : tuple[np.ndarray, np.ndarray] : x and y coordinates of the contour points. - anterior_endpoint_idx : np.ndarray : Index of the anterior endpoint. - posterior_endpoint_idx : np.ndarray : Index of the posterior endpoint. 
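As background for the contour smoothing used by the template code above: the moving-average filter is simply a windowed mean applied separately to the x and y coordinate sequences. A minimal sketch under the assumption of a closed contour (this uses scipy's uniform_filter1d rather than the project's own smooth_contour helper, and the function name is hypothetical):

```python
import numpy as np
from scipy.ndimage import uniform_filter1d

def moving_average_smooth(x: np.ndarray, y: np.ndarray, window_size: int = 5):
    # mode="wrap" treats the contour as closed, so the smoothing window
    # does not distort the polygon near its start/end points
    return (uniform_filter1d(x, size=window_size, mode="wrap"),
            uniform_filter1d(y, size=window_size, mode="wrap"))
```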
@@ -95,13 +98,9 @@ f"FREESURFER_HOME environment variable") from err fsaverage_seg_path = freesurfer_home / 'subjects' / 'fsaverage' / 'mri' / 'aparc+aseg.mgz' - fsaverage_seg = nib.load(fsaverage_seg_path) + fsaverage_seg = cast(nibabelImage, nib.load(fsaverage_seg_path)) segmentation = np.asarray(fsaverage_seg.dataobj) - PC = np.array([131, 99]) - AC = np.array([135, 130]) - - midslice = segmentation.shape[0]//2 +1 cc_mask = mask_in_array(segmentation[midslice], constants.SUBSEGMENT_LABELS) @@ -124,9 +123,9 @@ _, contour_with_thickness, (anterior_endpoint_idx, posterior_endpoint_idx) = recon_cc_surf_measure( segmentation=cc_mask[None], slice_idx=0, - ac_coords=AC, - pc_coords=PC, - affine=fsaverage_seg.affine, + ac_coords_vox=FSAVERAGE_AC_COORDINATE, + pc_coords_vox=FSAVERAGE_PC_COORDINATE, + slice_lia_vox2midslice_ras=fsaverage_seg.affine, num_thickness_points=100, subdivisions=[1/6, 1/2, 2/3, 3/4], subdivision_method="shape", @@ -148,7 +147,7 @@ fsaverage_contour = CCContour(np.array(outside_contour).T, np.zeros(len(outside_contour[0])), endpoint_idxs=(anterior_endpoint_idx, posterior_endpoint_idx), - resolution=1.0) + z_position=0.0) return fsaverage_contour diff --git a/CorpusCallosum/data/read_write.py b/CorpusCallosum/data/read_write.py index fcc38c17..673d7d31 100644 --- a/CorpusCallosum/data/read_write.py +++ b/CorpusCallosum/data/read_write.py @@ -20,17 +20,18 @@ from numpy import typing as npt import FastSurferCNN.utils.logging as logging -from FastSurferCNN.utils import AffineMatrix4x4, nibabelImage +from FastSurferCNN.utils import AffineMatrix4x4, RotationMatrix3x3, Vector3d, nibabelImage from FastSurferCNN.utils.parallel import thread_executor +logger = logging.get_logger(__name__) -class FSAverageHeader(TypedDict): - dims: npt.NDArray[int] - delta: npt.NDArray[float] - Mdc: npt.NDArray[float] - Pxyz_c: npt.NDArray[float] -logger = logging.get_logger(__name__) +class MGHHeaderDict(TypedDict): + """A dictionary with the four required fields of an MGH header.""" + dims: Vector3d + delta: Vector3d + Mdc: RotationMatrix3x3 + Pxyz_c: Vector3d def calc_ras_centroids_from_seg(seg_img: nibabelImage, label_ids: list[int] | None = None) \ @@ -154,7 +155,7 @@ def load_fsaverage_affine(affine_path: str | Path) -> npt.NDArray[float]: return affine_matrix -def load_fsaverage_data(data_path: str | Path) -> tuple[AffineMatrix4x4, FSAverageHeader]: +def load_fsaverage_data(data_path: str | Path) -> tuple[AffineMatrix4x4, MGHHeaderDict]: """Load fsaverage affine matrix and header fields from static JSON file.
Parameters @@ -206,13 +207,13 @@ def load_fsaverage_data(data_path: str | Path) -> tuple[AffineMatrix4x4, FSAvera # Convert lists back to numpy arrays affine_matrix = np.array(data["affine"]) - header_data = FSAverageHeader( + header_data = MGHHeaderDict( dims=data["header"]["dims"], delta=data["header"]["delta"], Mdc=np.array(data["header"]["Mdc"]), Pxyz_c=np.array(data["header"]["Pxyz_c"]), ) - + # Validate affine matrix shape if affine_matrix.shape != (4, 4): raise ValueError(f"Expected 4x4 affine matrix, got shape {affine_matrix.shape}") diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index dfbeaa0e..d8cbdb80 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -16,7 +16,6 @@ import argparse import json from collections.abc import Iterable -from functools import partial from pathlib import Path from time import perf_counter_ns from typing import Literal, TypeVar, cast @@ -25,6 +24,7 @@ import numpy as np import torch from monai.networks.nets import DenseNet +from nibabel.freesurfer.mghformat import MGHHeader from scipy.ndimage import affine_transform from CorpusCallosum.data.constants import ( @@ -37,7 +37,7 @@ THIRD_VENTRICLE_LABEL, ) from CorpusCallosum.data.read_write import ( - FSAverageHeader, + MGHHeaderDict, calc_ras_centroids_from_seg, convert_numpy_to_json_serializable, load_fsaverage_centroids, @@ -48,8 +48,8 @@ from CorpusCallosum.segmentation import segmentation_postprocessing from CorpusCallosum.shape.postprocessing import ( check_area_changes, - create_sag_slice_vox2vox, make_subdivision_mask, + offset_affine, recon_cc_surf_measures_multi, ) from CorpusCallosum.utils.mapping_helpers import ( @@ -61,7 +61,17 @@ from CorpusCallosum.utils.types import CCMeasuresDict, SliceSelection, SubdivisionMethod from FastSurferCNN.data_loader.conform import conform, is_conform from FastSurferCNN.segstats import HelpFormatter -from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Image4d, Mask3d, Shape3d, Vector2d, logging, nibabelImage +from FastSurferCNN.utils import ( + AffineMatrix4x4, + Image3d, + Image4d, + Mask3d, + Shape3d, + Vector2d, + logging, + nibabelHeader, + nibabelImage, +) from FastSurferCNN.utils.arg_types import path_or_none from FastSurferCNN.utils.common import SubjectDirectory, find_device from FastSurferCNN.utils.lta import write_lta @@ -249,13 +259,15 @@ def _slice_selection(a: str) -> SliceSelection: "--surf", dest="cc_surf", type=path_or_none, - help="Output path for surf file.", + help="Output path for surf file for visualization in freeview, use --save_template_dir and contours.txt to " + "obtain source CC contours.", default=DEFAULT_OUTPUT_PATHS["cc_surf"], ) advanced.add_argument( "--thickness_overlay", type=path_or_none, - help="Output path for corpus callosum thickness overlay file.", + help="Output path for corpus callosum thickness overlay file for visualization in freeview, use " + "--save_template_dir and thickness_values.txt to obtain source CC thickness values.", default=DEFAULT_OUTPUT_PATHS["cc_thickness_overlay"], ) advanced.add_argument( @@ -268,7 +280,8 @@ def _slice_selection(a: str) -> SliceSelection: advanced.add_argument( "--cc_surf_vtk", type=path_or_none, - help=f"Output path for vtk file, showing the CC 3D mesh. Example: {DEFAULT_OUTPUT_PATHS['cc_surf_vtk']}.", + help=f"Output path for vtk file, showing the CC 3D mesh for visualization, use --save_template_dir and " + f"contours.txt to obtain source CC contours. 
Example: {DEFAULT_OUTPUT_PATHS['cc_surf_vtk']}.", default=None, ) advanced.add_argument( @@ -366,7 +379,7 @@ def options_parse() -> argparse.Namespace: def register_centroids_to_fsavg(aseg_nib: nibabelImage) \ - -> tuple[AffineMatrix4x4, AffineMatrix4x4, AffineMatrix4x4, FSAverageHeader]: + -> tuple[AffineMatrix4x4, AffineMatrix4x4, AffineMatrix4x4, MGHHeaderDict]: """Perform centroid-based registration between subject and fsaverage space. Computes a rigid transformation between the subject's segmentation and fsaverage space @@ -385,7 +398,7 @@ def register_centroids_to_fsavg(aseg_nib: nibabelImage) \ Transformation matrix from original to fsaverage RAS space. fsaverage_hires_vox2ras : AffineMatrix4x4 High-resolution fsaverage affine matrix. - fsaverage_header : FSAverageHeader + fsaverage_header : MGHHeaderDict FSAverage header fields for LTA writing. Notes @@ -411,11 +424,11 @@ def register_centroids_to_fsavg(aseg_nib: nibabelImage) \ aseg2fsaverage_ras2ras: AffineMatrix4x4 = find_rigid(p_mov=ras_centroids_mov.T, p_dst=ras_centroids_dst.T) # make affine that increases resolution to orig resolution - aseg_zooms = list(nib.as_closest_canonical(aseg_nib).header.get_zooms()[:3]) - resolution_trans: AffineMatrix4x4 = np.diagflat([aseg_zooms[0], aseg_zooms[2], aseg_zooms[1], 1]).astype(float) + aseg_zooms_ras = np.asarray(nib.as_closest_canonical(aseg_nib).header.get_zooms()[:3]) + resolution_trans: AffineMatrix4x4 = np.diagflat(np.append(aseg_zooms_ras[[0, 2, 1]], [1])).astype(float) fsaverage_vox2ras, fsavg_header = fsaverage_data_future.result() - fsavg_header["delta"] = np.asarray([aseg_zooms[0], aseg_zooms[2], aseg_zooms[1]]) # vox sizes in lia + fsavg_header["delta"] = aseg_zooms_ras[[0, 2, 1]] # vox sizes in lia # fsavg_hires_vox2ras translation should be 128 always (independent of resolution) fsavg_hires_vox2ras: AffineMatrix4x4 = np.concatenate( [(resolution_trans @ fsaverage_vox2ras)[:, :3], fsaverage_vox2ras[:, 3:4]], @@ -423,6 +436,11 @@ def register_centroids_to_fsavg(aseg_nib: nibabelImage) \ ) fsavg_header["dims"] = np.ceil(fsavg_header["dims"] @ np.linalg.inv(resolution_trans[:3, :3])).astype(int).tolist() + # Correct fsavg_header["Pxyz_c"] by (vox_size - 1) / 2 in all three directions, because Pxyz_c is not actually in + # the center of the image, but in the center of the voxel in increasing voxel index direction, i.e. index 128 for a + # 256 image (where the center would be at 127.5). + fsavg_header["Pxyz_c"] += (aseg_zooms_ras - 1) / 2 @ fsavg_header["Mdc"] + aseg2fsavg_vox2vox: AffineMatrix4x4 = np.linalg.inv(fsavg_hires_vox2ras) @ aseg2fsaverage_ras2ras @ aseg_nib.affine logger.info("Centroid registration successful!") return aseg2fsavg_vox2vox, aseg2fsaverage_ras2ras, fsavg_hires_vox2ras, fsavg_header @@ -456,9 +474,9 @@ def localize_ac_pc( Returns ------- ac_coords : np.ndarray - Coordinates of the anterior commissure. + AC voxel coordinates with shape (2,) containing its [y,x] positions. pc_coords : np.ndarray - Coordinates of the posterior commissure. + PC voxel coordinates with shape (2,) containing its [y,x] positions. """ num_slices_to_analyze = resample_shape[0] resample_shape = (num_slices_to_analyze + 2,) + resample_shape[1:] # 2 for context slices @@ -501,11 +519,11 @@ def segment_cc( Parameters ---------- midslices : np.ndarray - Array of mid-sagittal slices. + Array of mid-sagittal slices in upright space and LIA-orientation. ac_coords : np.ndarray - Anterior commissure coordinates. + AC voxel coordinates with shape (2,) containing its [y,x] positions. 
pc_coords : np.ndarray - Posterior commissure coordinates. + PC voxel coordinates with shape (2,) containing its [y,x] positions. aseg_nib : nibabelImage Subject's cc_seg_labels image. model_segmentation : torch.nn.Module @@ -514,9 +532,9 @@ Returns ------- cc_seg_labels : np.ndarray - Binary cc_seg_labels of the corpus callosum. + Binary cc_seg_labels of the corpus callosum in upright space and LIA-orientation. cc_softlabels : np.ndarray - Soft cc_seg_labels probabilities of shape (H, W, D, C=3). + Soft cc_seg_labels probabilities of shape (H, W, D, C=3) in upright space and LIA-orientation. """ pre_clean_segmentation, inputs, cc_softlabels = segmentation_inference.run_inference_on_slice( model_segmentation, @@ -548,7 +566,7 @@ def main( num_thickness_points: int = 100, subdivisions: list[float] | None = None, subdivision_method: SubdivisionMethod = "shape", - contour_smoothing: float = 5, + contour_smoothing: int = 5, save_template_dir: str | Path | None = None, device: str | torch.device = "auto", upright_volume: str | Path | None = None, @@ -589,7 +607,7 @@ List of subdivision fractions for CC subsegmentation. subdivision_method : any of "shape", "vertical", "angular", "eigenvector", default="shape" Method for contour subdivision. - contour_smoothing : float, default=5 + contour_smoothing : int, default=5 Gaussian sigma for smoothing during contour detection. save_template_dir : str or Path, optional Directory path where to save contours.txt and thickness_values.txt files. These files can be used to visualize @@ -719,7 +737,7 @@ slices_to_analyze += 1 logger.info( - f"Segmenting {slices_to_analyze} slices (5 mm width at {vox_size[0]} mm resolution, " + f"Segmenting {slices_to_analyze} slices (5 mm width at {vox_size[0]:.3f} mm resolution, " "center around the mid-sagittal plane)" ) @@ -730,39 +748,44 @@ sys.exit(1) logger.info("Performing centroid registration to fsaverage space") - orig2fsavg_vox2vox, orig2fsavg_ras2ras, fsavg_vox2ras, fsavg_header = register_centroids_to_fsavg(aseg_img) + orig2fsavg_vox2vox, orig2fsavg_ras2ras, fsavg_vox2ras, _fsavg_header_dict = register_centroids_to_fsavg(aseg_img) + fsavg_header = init_mgh_header(orig.header, _fsavg_header_dict) # start saving upright volume, this is the image in fsaverage space but not yet oriented via AC-PC if sd.has_attribute("upright_volume"): # upright == fsaverage-aligned + # FIXME: upright currently does not get saved correctly io_futures.append( thread_executor().submit( apply_transform_to_volume, orig, orig2fsavg_vox2vox, - fsavg_vox2ras, + save_vox2ras=fsavg_vox2ras, output_path=sd.filename_by_attribute("upright_volume"), - output_size=fsavg_header["dims"], + output_size=fsavg_header["dims"][:3], ) ) # calculate affine for segmentation volume - affine_x_offset = partial(create_sag_slice_vox2vox, fsaverage_middle=FSAVERAGE_MIDDLE / vox_size[0]) - fsavg2midslab_in_vox2vox: AffineMatrix4x4 = affine_x_offset(slices_to_analyze // 2) - # first, midslice->fsaverage in vox2vox, then vox2ras in fsaverage space - fsaverage_midslab_vox2ras: AffineMatrix4x4 = fsavg_vox2ras @ np.linalg.inv(fsavg2midslab_in_vox2vox) + fsavg2midslice_vox2vox: AffineMatrix4x4 = offset_affine([-FSAVERAGE_MIDDLE / vox_size[0], 0, 0]) + orig2midslice_vox2vox = fsavg2midslice_vox2vox @ orig2fsavg_vox2vox # calculate vox2vox for input resampling volumes def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: - fsavg2midslab = affine_x_offset(slices_to_analyze // 2 + additional_context // 2) - #
first, orig->fsaverage in vox2vox, then fsaverage->midslab in vox2vox - return fsavg2midslab @ orig2fsavg_vox2vox + fsavg2midslab = offset_affine([slices_to_analyze // 2 + additional_context // 2, 0, 0]) + # first, orig->fsaverage, then fsaverage->midslab (all in vox2vox) + return fsavg2midslab @ orig2midslice_vox2vox + + # first, midslice->fsaverage in vox2vox, then vox2ras in fsaverage space + fsavg2midslab_vox2vox = offset_affine([slices_to_analyze // 2, 0, 0]) @ fsavg2midslice_vox2vox + fsaverage_midslab_vox2ras: AffineMatrix4x4 = fsavg_vox2ras @ np.linalg.inv(fsavg2midslab_vox2vox) + #### do localization and segmentation inference logger.info("Starting AC/PC localization") target_shape: tuple[int, int, int] = (slices_to_analyze, fsavg_header["dims"][1], fsavg_header["dims"][2]) # predict ac and pc coordinates in upright AS space - ac_coords, pc_coords = localize_ac_pc( + ac_coords_vox, pc_coords_vox = localize_ac_pc( np.asarray(orig.dataobj), aseg_img, _orig2midslab_vox2vox(additional_context=2), @@ -783,8 +806,8 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: ) cc_fn_seg_labels, cc_fn_softlabels = segment_cc( midslices, - ac_coords, - pc_coords, + ac_coords_vox, + pc_coords_vox, aseg_img, _model_segmentation.result(), ) @@ -805,15 +828,17 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: logger.info(f"Processing slices with selection mode: {slice_selection}") slice_results, slice_io_futures = recon_cc_surf_measures_multi( segmentation=cc_fn_seg_labels, - upright_affine_header=(fsavg_vox2ras, orig.header), slice_selection=slice_selection, + upright_header=fsavg_header, + fsavg2midslab_vox2vox=fsavg2midslab_vox2vox, fsavg_vox2ras=fsavg_vox2ras, + orig2fsavg_vox2vox=orig2fsavg_vox2vox, midslices=midslices, - ac_coords=ac_coords, - pc_coords=pc_coords, + ac_coords_vox=ac_coords_vox, + pc_coords_vox=pc_coords_vox, num_thickness_points=num_thickness_points, subdivisions=subdivisions, - subdivision_method=subdivision_method, + subdivision_method=cast(SubdivisionMethod, subdivision_method), contour_smoothing=contour_smoothing, vox_size=vox_size, subject_dir=sd, @@ -857,7 +882,7 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: orig_space_segmentation_path=sd.filename_by_attribute("cc_orig_segfile"), orig2slab_vox2vox=_orig2midslab_vox2vox(), cc_subseg_midslice=cc_subseg_midslice, - orig2midslice_vox2vox=affine_x_offset(0) @ orig2fsavg_vox2vox, # orig2fsavg, then full2midslice + orig2midslice_vox2vox=orig2midslice_vox2vox, )) METRICS = [ @@ -890,7 +915,8 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: voxel_size=vox_size, # in LIA order ) logger.info(f"CC volume voxel: {cc_volume_voxel}") - # FIXME: Create a proper mesh and use cc_mesh.volume for this volume + # FIXME: Create a proper mesh and use cc_mesh.volume for this volume --> not closed, but move function to + # CCContour? 
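Stepping back to the affine plumbing above: the vox2vox chaining composes homogeneous transforms right-to-left, and a derived vox2ras "undoes" the vox2vox before applying the original vox2ras. A minimal sketch of the pattern (offset_affine is assumed here to build a pure-translation 4x4 matrix, matching its use in this patch; the identity matrices are stand-ins for the real registration results):

```python
import numpy as np

def offset_affine(offset) -> np.ndarray:
    # assumed behavior: a 4x4 homogeneous matrix that only translates
    m = np.eye(4)
    m[:3, 3] = offset
    return m

orig2fsavg_vox2vox = np.eye(4)  # stand-in for the centroid registration
fsavg2midslice_vox2vox = offset_affine([-128.0, 0.0, 0.0])

# transforms compose right-to-left: orig -> fsaverage -> midslice
orig2midslice_vox2vox = fsavg2midslice_vox2vox @ orig2fsavg_vox2vox

# vox2ras for the shifted space: undo the vox2vox, then apply the fsaverage vox2ras
fsavg_vox2ras = np.eye(4)  # stand-in
midslice_vox2ras = fsavg_vox2ras @ np.linalg.inv(fsavg2midslice_vox2vox)
```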
try: cc_volume_contour = segmentation_postprocessing.get_cc_volume_contour( cc_contours=outer_contours, @@ -906,8 +932,8 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: additional_metrics["cc_5mm_volume_pv_corrected"] = cc_volume_contour # get ac and pc in all spaces - ac_coords_3d = np.hstack((FSAVERAGE_MIDDLE, ac_coords)) - pc_coords_3d = np.hstack((FSAVERAGE_MIDDLE, pc_coords)) + ac_coords_3d = np.hstack((FSAVERAGE_MIDDLE, ac_coords_vox)) + pc_coords_3d = np.hstack((FSAVERAGE_MIDDLE, pc_coords_vox)) standardized2orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig = ( calc_mapping_to_standard_space(orig, ac_coords_3d, pc_coords_3d, orig2fsavg_vox2vox) ) @@ -933,14 +959,14 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: save_cc_measures_json, sd.filename_by_attribute('cc_mid_measures'), output_metrics_middle_slice | additional_metrics, - )) + )) if sd.has_attribute("cc_measures"): io_futures.append(thread_executor().submit( save_cc_measures_json, sd.filename_by_attribute("cc_measures"), per_slice_output_dict | additional_metrics, - )) + )) # save lta to fsaverage space @@ -949,7 +975,7 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: logger.info(f"Saving LTA to fsaverage space: {sd.filename_by_attribute('upright_lta')}") io_futures.append(thread_executor().submit( write_lta, - sd.filename_by_attribute("upright_lta"), + sd.filename_by_attribute("upright_lta"), orig2fsavg_ras2ras, sd.filename_by_attribute("aseg_name"), aseg_img.header, @@ -983,6 +1009,31 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: logger.info(f"CorpusCallosum analysis pipeline completed successfully in {duration:.2f} seconds.") +def init_mgh_header(header: nibabelHeader, header_dict: MGHHeaderDict) -> MGHHeader: + """ + Generates a MGHHeader object from a header and a header dictionary. + + Parameters + ---------- + header : nibabelHeader + The header object used to initialize the generated header. + header_dict : MGHHeaderDict + A dictionary of values to overwrite in the generated header. + + Returns + ------- + MGHHeader + The header updated with values in header_dict. + """ + new_header: MGHHeader = MGHHeader.from_header(header) + if "dims" in header_dict: + new_header["dims"] = np.append(header_dict["dims"], [1]) + for key in ("delta", "Pxyz_c", "Mdc"): + if key in header_dict: + new_header[key] = header_dict[key] + return new_header + + def save_cc_measures_json(cc_mid_measure_file: Path, metrics: dict[str, object]): """Save JSON metrics file.""" # Convert numpy arrays to lists for JSON serialization @@ -1002,7 +1053,6 @@ def save_cc_measures_json(cc_mid_measure_file: Path, metrics: dict[str, object]) conf_name=options.conf_name, aseg_name=options.aseg_name, subject_dir=options.subject_dir, - #FIXME: slice_selection is True/bool slice_selection=options.slice_selection, num_thickness_points=options.num_thickness_points, subdivisions=list(options.subdivisions), diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py index fb62ca94..68b238fb 100644 --- a/CorpusCallosum/segmentation/segmentation_postprocessing.py +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -383,7 +383,9 @@ def get_cc_volume_contour( using Simpson's rule. If the CC width is larger than desired_width_mm, the voxels on the edges are calculated as partial volumes to achieve the desired width. 
""" - # FIXME: This function is a shape-tool, it should therefore not be in segmentation.postprocessing... + # FIXME: move to CCContour --> area + + # FIXME: this code currently produces volume estimates more that 50% off of the volume_based estimate in # get_cc_volume_voxel... @@ -421,7 +423,7 @@ def get_cc_volume_contour( measurement_points = np.arange(-voxel_size[0]*(areas.shape[0]//2), voxel_size[0]*((areas.shape[0]+1)//2), lr_spacing) - # FIXME: why interpolate at 0.25? Also, why do we need interpolaton at all? + # FIXME: why interpolate at 0.25? Also, why do we need interpolation at all? # interpolate areas at 0.25 and 5 areas_interpolated = np.interp(x=[-2.5, 2.5], xp=measurement_points, diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py index 272b7a5f..97ca7be6 100644 --- a/CorpusCallosum/shape/contour.py +++ b/CorpusCallosum/shape/contour.py @@ -27,7 +27,7 @@ import re from pathlib import Path -from typing import Literal +from typing import TypeVar import lapy import matplotlib @@ -37,12 +37,17 @@ from scipy.ndimage import gaussian_filter1d import FastSurferCNN.utils.logging as logging -from CorpusCallosum.shape.endpoint_heuristic import smooth_contour +from CorpusCallosum.shape.endpoint_heuristic import find_cc_endpoints, smooth_contour from CorpusCallosum.shape.thickness import cc_thickness, make_mesh_from_contour +from CorpusCallosum.utils.types import Points2dType +from FastSurferCNN.utils import AffineMatrix4x4, Mask2d, Vector2d logger = logging.get_logger(__name__) +Self = TypeVar("Self", bound="CCContour") + +# FIXME: Maybe CCContur should inherit from Polygon at a later date? class CCContour: """A class for representing and manipulating corpus callosum (CC) contours. @@ -50,7 +55,7 @@ class CCContour: Attributes ---------- - contour : np.ndarray + points : np.ndarray Array of shape (N, 2) containing 2D contour points. thickness_values : np.ndarray Array of shape (N,) for thickness measurements for each contour point. @@ -63,52 +68,61 @@ class CCContour: >>> >>> contour = CCContour(contour_points, thickness_values, >>> endpoint_idxs=(anterior_idx, posterior_idx), - >>> resolution=1.0) + >>> z_position=0.0) >>> contour.fill_thickness_values() # interpolate missing values >>> contour.smooth_contour(window_size=5) >>> contour.save_contour("contour_0.txt") >>> contour.save_thickness_values("thickness_values_0.txt") >>> contour.save_thickness_measurement_points("thickness_measurement_points_0.txt") """ + + def __init__( self, - contour: np.ndarray[tuple[Literal["N", 2]], np.dtype[np.float_]], - thickness_values: np.ndarray[tuple[Literal["N"]], np.dtype[np.float_]], + points: Points2dType, + thickness_values: np.ndarray[tuple[int], np.dtype[np.float_]] | None, endpoint_idxs: tuple[int, int] | None = None, - resolution: float = 1.0 + z_position: float = 0.0 ): """Initialize a CCContour object. Parameters ---------- - contour : np.ndarray + points : np.ndarray Array of shape (N, 2) containing 2D contour points. - thickness_values : np.ndarray + thickness_values : np.ndarray, optional Array of thickness measurements for each contour point. endpoint_idxs : tuple[int, int], optional Tuple containing start and end indices for the contour. - resolution : float, default=1.0 - The left-right spacing. + z_position : float, default=0.0 + The distance of the slice from midslice. 
""" - self.contour = contour - if self.contour.shape[1] != 2: - raise ValueError(f"Contour must be a 2D array, but is {self.contour.shape}") + self.points = points + if self.points.ndim != 2 or self.points.shape[1] != 2: + raise ValueError(f"Contour must be a (N, 2) array, but is {self.points.shape}") self.thickness_values = thickness_values - if self.contour.shape[0] != len(thickness_values): - raise ValueError( - f"Number of contour points ({self.contour.shape[0]}) does not match number of thickness values " - f"({len(thickness_values)})", - ) - # write vertex indices where thickness values are not nan - self.original_thickness_vertices = np.where(~np.isnan(thickness_values))[0] - self.resolution = resolution + if thickness_values is not None: + if self.points.shape[0] != len(thickness_values): + raise ValueError( + f"Number of contour points ({self.points.shape[0]}) does not match number of thickness values " + f"({len(thickness_values)})", + ) + # write vertex indices where thickness values are not nan + self.original_thickness_vertices = np.where(~np.isnan(thickness_values))[0] + else: + self.original_thickness_vertices = None + self.z_position = z_position if endpoint_idxs is None: - self.endpoint_idxs = (0, len(contour) // 2) + self.endpoint_idxs = (0, len(points) // 2) else: self.endpoint_idxs = endpoint_idxs + def __len__(self) -> int: + """Return the number of points on the contour.""" + return len(self.points) + def smooth_contour(self, window_size: int = 5) -> None: """Smooth a contour using a moving average filter. @@ -119,19 +133,14 @@ def smooth_contour(self, window_size: int = 5) -> None: Notes ----- - Uses smooth_contour from cc_endpoint_heuristic module to: - 1. Extract x and y coordinates. - 2. Apply moving average smoothing. - 3. Update contour with smoothed coordinates. + Uses smooth_contour from cc_endpoint_heuristic module. """ - x, y = self.contour.T - x, y = smooth_contour(x, y, window_size) - self.contour = np.array([x, y]).T + self.points = np.array([smooth_contour(*self.points.T, window_size=window_size)]).T def copy(self) -> "CCContour": """Copy the contour. """ - return CCContour(self.contour.copy(), self.thickness_values.copy(), self.endpoint_idxs, self.resolution) + return CCContour(self.points.copy(), self.thickness_values.copy(), self.endpoint_idxs, self.z_position) def get_contour_edge_lengths(self) -> np.ndarray: """Get the lengths of the edges of a contour. @@ -146,22 +155,29 @@ def get_contour_edge_lengths(self) -> np.ndarray: Edge lengths are calculated as Euclidean distances between consecutive points in the contour. """ - edges = np.diff(self.contour, axis=0) + edges = np.diff(self.points, axis=0) return np.sqrt(np.sum(edges**2, axis=1)) - def create_levelpaths(self, - num_points: int, - update_data: bool = True - ) -> tuple[list[np.ndarray], list[float]]: - midline_len, thickness, curvature, midline_equi, \ - levelpaths, contour_with_thickness, endpoint_idxs = cc_thickness( - self.contour, + def create_levelpaths( + self, + num_points: int, + update_data: bool = True + ) -> tuple[list[np.ndarray], float]: + #FIXME: docstring + + # FIXME: cache all these values in CCContour, and invalidate the cache, when either points or endpoint_idxs get + # changed; alternatively, make points and endpoint_idxs read_only (by creating getter-only properties) + # and have all functions that change points or endpoints return a new CCContour object instead. 
+ midline_len, thickness, curvature, midline_equi, levelpaths, contour_with_thickness, endpoint_idxs = \ + cc_thickness( + self.points, self.endpoint_idxs, n_points=num_points, ) if update_data: - self.contour = contour_with_thickness[:, :2] + # FIXME: as an alternative to update_data, use "inplace" ; always return the CCContour object? + self.points = contour_with_thickness[:, :2] self.thickness_values = contour_with_thickness[:,2] self.original_thickness_vertices = np.where(~np.isnan(self.thickness_values))[0] self.endpoint_idxs = endpoint_idxs @@ -176,12 +192,20 @@ def set_thickness_values(self, thickness_values: np.ndarray, use_measurement_poi ---------- thickness_values : np.ndarray Array of thickness values for the contour. - use_measurement_points : bool, optional - Whether to use the measurement points to set the thickness values, by default False. + use_measurement_points : bool, default=False + Whether to use the measurement points to set the thickness values. """ if use_measurement_points: - if len(thickness_values) == len(self.original_thickness_vertices): - self.thickness_values = np.full(len(self.contour), np.nan) + if self.original_thickness_vertices is None: + if len(thickness_values) != len(self.points): + raise ValueError( + f"Thickness values not initialized and number of points in the contour {len(self.points)} does " + f"not match number of thickness values {len(thickness_values)}.", + ) + self.original_thickness_vertices = np.where(~np.isnan(thickness_values))[0] + self.thickness_values = thickness_values + elif len(thickness_values) == len(self.original_thickness_vertices): + self.thickness_values = np.full(len(self.points), np.nan) self.thickness_values[self.original_thickness_vertices] = thickness_values else: raise ValueError( @@ -189,10 +213,10 @@ def set_thickness_values(self, thickness_values: np.ndarray, use_measurement_poi f"{len(self.original_thickness_vertices)}.", ) else: - if len(thickness_values) != len(self.contour): + if len(thickness_values) != len(self.points): raise ValueError( f"The number of thickness values does not match number of points in the contour " - f"{len(self.contour)}.", + f"{len(self.points)}.", ) self.thickness_values = thickness_values @@ -291,28 +315,21 @@ def plot_contour(self, output_path: str | None = None) -> None: if output_path is not None: self.__make_parent_folder(output_path) - contour = self.contour - plt.figure(figsize=(10, 10)) - # Get thickness values for this slice - thickness = self.thickness_values # Plot points with colors based on thickness - for i in range(len(contour)): - if np.isnan(thickness[i]): - plt.plot(contour[i, 0], contour[i, 1], "o", color="gray", markersize=1) - else: - # Map thickness to color from red to yellow - plt.plot( - contour[i, 0], - contour[i, 1], - "o", - color=plt.cm.YlOrRd(thickness[i] / np.nanmax(thickness)), - markersize=1, - ) + gray_points = np.isnan(self.thickness_values) + if np.any(gray_points): + plt.plot(*self.points[gray_points, :2].T, "o", color="gray", markersize=1) + + if not np.all(gray_points): + not_gray = np.logical_not(gray_points) + color_values = plt.cm.YlOrRd(self.thickness_values[not_gray] / np.nanmax(self.thickness_values[not_gray])) + # Map thickness to color from red to yellow + plt.plot(*self.points[~gray_points, :2].T, "o", color=color_values, markersize=1) # Connect points with lines - plt.plot(contour[:, 0], contour[:, 1], "-", color="black", alpha=0.3, label="Contour") + plt.plot(self.points[:, 0], self.points[:, 1], "-", color="black", alpha=0.3, 
label="Contour") plt.axis("equal") plt.xlabel("X") plt.ylabel("Y") @@ -325,7 +342,6 @@ def plot_contour(self, output_path: str | None = None) -> None: else: plt.show() - def plot_contour_colorfill( self, plot_values: np.ndarray, @@ -359,14 +375,14 @@ def plot_contour_colorfill( """ plot_values = plot_values[::-1] # make sure values are plotted left to right (anterior to posterior) - points, _ = make_mesh_from_contour(self.contour, max_volume=0.5, min_angle=25, verbose=False) + points, _ = make_mesh_from_contour(self.points, max_volume=0.5, min_angle=25, verbose=False) # make points 3D by adding zero points = np.column_stack([points, np.zeros(len(points))]) - levelpaths, _ = self.create_levelpaths(num_points=len(plot_values)-1, update_data=False) + levelpaths, *_ = self.create_levelpaths(num_points=len(plot_values)-1, update_data=False) - outside_contour = self.contour.T + outside_contour = self.points.T # Create a grid of points covering the contour area with higher resolution x_min, x_max = np.min(outside_contour[0]), np.max(outside_contour[0]) @@ -541,8 +557,7 @@ def __make_parent_folder(filename: Path | str) -> None: Notes ----- - Creates parent directory with parents=False to avoid creating - multiple levels of directories unintentionally. + Creates parent directory with parents=False to avoid creating multiple levels of directories unintentionally. """ Path(filename).parent.mkdir(parents=False, exist_ok=True) @@ -570,16 +585,26 @@ def save_contour(self, output_path: Path | str) -> None: f"posterior_endpoint_idx={self.endpoint_idxs[1]}\n" ) f.write("x,y\n") - for point in self.contour: + for point in self.points: f.write(f"{point[0]},{point[1]}\n") - def load_contour(self, input_path: str) -> None: + @classmethod + def from_contour_file( + cls: type[Self], + input_path: str | Path, + thickness_values_path: str | Path, + z_position: float = 0.0, + ) -> Self: """Load contour from a CSV file. Parameters ---------- - input_path : str + input_path : str, Path Path to the CSV file containing the contours. + thickness_values_path : str, Path + Path to the CSV file containing the thickness_values. + z_position : float, default=0.0 + The distance to the midslice (in fsaverage space). Raises ------ @@ -595,8 +620,7 @@ def load_contour(self, input_path: str) -> None: 4. Converts lists to fixed-size arrays with None padding. """ current_points = [] - self.contours = [] - self.endpoint_idxs = [] + endpoint_idxs = [] with open(input_path) as f: header = next(f).strip() @@ -607,7 +631,7 @@ def load_contour(self, input_path: str) -> None: anterior_idx = int(anterior_match.group(1)) posterior_idx = int(posterior_match.group(1)) - self.endpoint_idxs = (anterior_idx, posterior_idx) + endpoint_idxs = (anterior_idx, posterior_idx) # Skip column names next(f) @@ -615,7 +639,12 @@ def load_contour(self, input_path: str) -> None: for line in f: x, y = line.strip().split(",") current_points.append([float(x), float(y)]) - self.contour = np.array(current_points) + contour = np.array(current_points) + if thickness_values_path: + thickness_values = cls._load_thickness_values(contour, None, thickness_values_path) + else: + thickness_values = None + return CCContour(contour, thickness_values, endpoint_idxs, z_position=z_position) def save_thickness_values(self, output_path: Path | str) -> None: """Save thickness values to a CSV file. 
@@ -641,23 +670,33 @@ def save_thickness_values(self, output_path: Path | str) -> None: def load_thickness_values( self, - input_path: str, + input_path: str | Path, ) -> None: """Load thickness values from a CSV file. Parameters ---------- - input_path : str + input_path : Path, str Path to the CSV file containing thickness values. - original_thickness_vertices_path : str or None, optional - Path to a file containing the indices of vertices where thickness - was measured, by default None. Raises ------ ValueError If number of thickness values doesn't match measurement points or if number of slices is inconsistent. + """ + self.thickness_values = self._load_thickness_values(self.points, self.original_thickness_vertices, input_path) + + @classmethod + def _load_thickness_values( + cls, + contour: Points2dType, + original_thickness_vertices: np.ndarray[tuple[int], np.dtype[np.signedinteger]] | None, + input_path: str | Path, + ) -> np.ndarray[tuple[int], np.dtype[np.float_]]: + """See load_thickness_values. + + Ignore shape of thickness values if original_thickness_vertices is None. Notes ----- @@ -666,8 +705,6 @@ def load_thickness_values( 2. Groups values by slice index. 3. Optionally associates values with specific vertices. 4. Handles both full contour and profile measurements. - - """ data = np.loadtxt(input_path, delimiter=",", skiprows=1) if data.ndim == 0: @@ -677,18 +714,71 @@ def load_thickness_values( else: raise ValueError("Thickness values file must contain a single column") - if len(values) != len(self.contour): - if np.sum(~np.isnan(values)) == len(self.original_thickness_vertices): - new_values = np.full(len(self.contour), np.nan) - new_values[self.original_thickness_vertices] = values[~np.isnan(values)] + if len(values) != len(contour): + if original_thickness_vertices is None: + new_values = values + elif np.sum(~np.isnan(values)) == len(original_thickness_vertices): + new_values = np.full(len(contour), np.nan) + new_values[original_thickness_vertices] = values[~np.isnan(values)] else: raise ValueError( - f"Number of thickness values {len(values)} does not match number of points in the " - f"contour {len(self.contour)} and current number of measururement points " - f"{len(self.original_thickness_vertices)} does not match the number of set thickness values " - f"{np.sum(~np.isnan(values))}." + f"Number of thickness values {len(values)} does not match number of points in the contour " + f"{len(contour)} and current number of measurement points {len(original_thickness_vertices)} does " + f"not match the number of set thickness values {np.sum(~np.isnan(values))}." ) else: raise ValueError(f"Number of thickness values in {input_path} does not match the vertices of the path!") - self.thickness_values = new_values + return new_values + + @classmethod + def from_mask_and_acpc( + cls: type[Self], + cc_mask: Mask2d, + ac_2d: Vector2d, + pc_2d: Vector2d, + slice_vox2ras: AffineMatrix4x4, + contour_smoothing: int = 5 + ) -> Self: + """Extracts the contour of the CC using marching squares, smooth and transform to RAS coordinates. + + Parameters + ---------- + cc_mask : np.ndarray of shape (H, W) and type bool + Binary mask of the corpus callosum. + ac_2d : np.ndarray of shape (2,) and type float + 2D voxel coordinates of the anterior commissure. + pc_2d : np.ndarray of shape (2,) and type float + 2D voxel coordinates of the posterior commissure. + slice_vox2ras : AffineMatrix4x4 + Transformation matrix from slice-voxel space to RAS-coordinates. 
+ contour_smoothing : int, default=5 + Window size for contour smoothing. + + Returns + ------- + contour : CCContour + The contour object. + + Notes + ----- + Expects LIA orientation. + """ + import skimage.measure + + from CorpusCallosum.shape.endpoint_heuristic import smooth_contour + + contour = skimage.measure.find_contours(cc_mask, level=0.5)[0].T + #FIXME: maybe use Polygon.smooth_* + contour = np.array(smooth_contour(*contour, window_size=contour_smoothing)) + # Add z=0 coordinate to make 3D, then remove it after resampling + contour_3d = np.concatenate([np.zeros((1, contour.shape[1])), contour]) # ZIA, (3, N) + # FIXME: change this to using Polygon class when we upgrade lapy + contour_3d = lapy.tria_mesh.TriaMesh._TriaMesh__resample_polygon(contour_3d.T, 701).T + contour_ras = (slice_vox2ras[:3, :3] @ contour_3d) + slice_vox2ras[:3, [3]] + + ac_pc_3d = np.concatenate([[[0, 0]], np.stack([ac_2d, pc_2d], axis=1)]) # (3, 2) + ac_ras, pc_ras = ((slice_vox2ras[:3, :3] @ ac_pc_3d) + slice_vox2ras[:3, [3]]).T + endpoint_idx = find_cc_endpoints(contour_ras[1:], ac_ras[1:], pc_ras[1:]) + + return cls(contour_ras[1:].T, None, endpoint_idx, z_position=slice_vox2ras[0, 3]) diff --git a/CorpusCallosum/shape/endpoint_heuristic.py b/CorpusCallosum/shape/endpoint_heuristic.py index 2c89bade..f9dc084c 100644 --- a/CorpusCallosum/shape/endpoint_heuristic.py +++ b/CorpusCallosum/shape/endpoint_heuristic.py @@ -13,13 +13,14 @@ # limitations under the License. from typing import Literal, overload -import lapy +import lapy.tria_mesh import numpy as np import scipy.ndimage import skimage.measure from scipy.ndimage import label -from FastSurferCNN.utils import Mask2d, Vector2d +from CorpusCallosum.utils.types import Points2dType, Polygon2dType +from FastSurferCNN.utils import Image2d, Mask2d, Vector2d def smooth_contour(x: np.ndarray, y: np.ndarray, window_size: int) -> tuple[np.ndarray, np.ndarray]: @@ -68,7 +69,7 @@ def smooth_contour(x: np.ndarray, y: np.ndarray, window_size: int) -> tuple[np.n return x_smoothed, y_smoothed -def connect_diagonally_connected_components(cc_mask: np.ndarray) -> None: +def connect_diagonally_connected_components(cc_mask: Image2d) -> Image2d: """Connect diagonally connected components in the CC mask. 
Parameters @@ -105,8 +106,7 @@ def connect_diagonally_connected_components(cc_mask: np.ndarray) -> None: ((down_left > 0) & ((right > 0) | (up > 0))) | ((down_right > 0) & ((left > 0) | (up > 0))) ) - - + # Get connected components before filling using 4-connectivity # This way, diagonal-only connections are treated as separate components structure_4conn = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]) @@ -115,25 +115,27 @@ def connect_diagonally_connected_components(cc_mask: np.ndarray) -> None: # For each potential gap, check if filling it would reduce the number of components connects_diagonals = np.zeros_like(potential_diagonal_gaps) gap_positions = np.where(potential_diagonal_gaps) - - for i, j in zip(gap_positions[0], gap_positions[1], strict=True): - # Temporarily fill this gap + + if len(gap_positions[0]) > 0: test_mask = cc_mask.copy() - test_mask[i, j] = 1 - - # Check connected components after filling - _, num_components_after = label(test_mask, structure=structure_4conn) - - # Only fill if it actually connects previously disconnected components - if num_components_after < num_components_before: - connects_diagonals[i, j] = True + # Fill all gap voxels, that by themselves would connect 2 components + for i, j in zip(gap_positions[0], gap_positions[1], strict=True): + # Temporarily fill this gap + test_mask[i, j] = 1 + # Check connected components after filling, this is relatively slow... + _, num_components_after = label(test_mask, structure=structure_4conn) + # Only fill if it actually connects previously disconnected components + if num_components_after < num_components_before: + connects_diagonals[i, j] = True + # Revert temporary fill + test_mask[i, j] = cc_mask[i, j] # Fill the identified diagonal gaps that actually improve connectivity - cc_mask[connects_diagonals] = 1 + return np.where(connects_diagonals, 1, cc_mask) -def extract_cc_contour(cc_mask: np.ndarray, contour_smoothing: int = 5) -> np.ndarray: - """Extract the contour of the CC from the mask. +def extract_cc_contour(cc_mask: Mask2d, contour_smoothing: int = 5) -> Polygon2dType: + """Extract the contour of the CC from the mask using a marching squares approach. Parameters ---------- @@ -144,31 +146,19 @@ def extract_cc_contour(cc_mask: np.ndarray, contour_smoothing: int = 5) -> np.nd Returns ------- - np.ndarray - Array of shape (2, N) containing x,y coordinates of the contour points. + lapy.polygon.Polygon + A lapy Polygon object with a closed polygon contour. """ - # cc_mask_orig = cc_mask - cc_mask = cc_mask.copy() - - connect_diagonally_connected_components(cc_mask) + cc_mask = connect_diagonally_connected_components(cc_mask) contour = skimage.measure.find_contours(cc_mask, level=0.5)[0].T contour = np.array(smooth_contour(contour[0], contour[1], contour_smoothing)) - # plot contour - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots(1,2,figsize=(10, 8)) - # ax[0].imshow(cc_mask_orig) - # ax[1].imshow(cc_mask) - # ax[0].plot(contour[1], contour[0], 'r-') - # ax[1].plot(contour[1], contour[0], 'r-') - # plt.show() - return contour @overload -def get_endpoints( +def find_contour_and_endpoints( cc_mask: Mask2d, ac_2d: Vector2d, pc_2d: Vector2d, @@ -179,7 +169,7 @@ def get_endpoints( @overload -def get_endpoints( +def find_contour_and_endpoints( cc_mask: Mask2d, ac_2d: Vector2d, pc_2d: Vector2d, @@ -189,7 +179,7 @@ def get_endpoints( ) -> tuple[np.ndarray, tuple[int, int]]: ... 
-def get_endpoints( +def find_contour_and_endpoints( cc_mask: Mask2d, ac_2d: Vector2d, pc_2d: Vector2d, @@ -197,16 +187,16 @@ def get_endpoints( return_coordinates: bool = False, contour_smoothing: int = 5 ): - """Determine endpoints of CC by finding points closest to AC and PC. + """Extracts the contour of the CC, rotates to AC-PC alignment, and determines closest points of CC to AC and PC. Parameters ---------- cc_mask : np.ndarray of shape (H, W) and type bool Binary mask of the corpus callosum. ac_2d : np.ndarray of shape (2,) and type float - 2D coordinates of the anterior commissure. + 2D voxel coordinates of the anterior commissure. pc_2d : np.ndarray of shape (2,) and type float - 2D coordinates of the posterior commissure. + 2D voxel coordinates of the posterior commissure. resolution : pair of floats Inslice image resolution in mm (inferior/superior and anterior/posterior directions). return_coordinates : bool, default=False @@ -217,7 +207,8 @@ def get_endpoints( Returns ------- contour_rotated : np.ndarray - The contour rotated to AC-PC alignment. + The contour in 2d voxel coordinates rotated to AC-PC alignment and with origin at center of image + (axis 0: I->S, axis 1: A->P). anterior_posterior_point_indices : pair of ints Indices of anterior and posterior points in the contour. anterior_posterior_point_coordinates : tuple[np.ndarray, np.ndarray] @@ -239,9 +230,24 @@ def get_endpoints( # Convert symbolic theta to float and convert from radians to degrees theta_degrees = theta * 180 / np.pi - rotated_cc_mask = scipy.ndimage.rotate(cc_mask, -theta_degrees, order=0, reshape=False) + # FIXME: Why do we rotate the mask before we do the marching squares? Instead of all this weird rotation everywhere + # here, it seems to me it would be the same to just rotate the offsets for the heuristic?! 
Like so: + # + # rot_matrix_inv = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + # rotated_pc_2d = pc_2d.astype(float) + rot_matrix_inv @ np.array([10, -5]) / resolution + # rotated_ac_2d = ac_2d.astype(float) + rot_matrix_inv @ np.array([0, 5]) / resolution + # + # contour = extract_cc_contour(cc_mask, contour_smoothing) + # # Add z=0 coordinate to make 3D, then remove it after resampling + # contour_3d = np.vstack([contour, np.zeros(contour.shape[1])]) + # contour_3d = __resample_polygon(contour_3d.T, 701).T + # contour = contour_3d[:2, :-1] + # # find point in contour closest to AC + # ac_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_ac_2d[:, None], axis=0)) + # # find point in contour closest to PC + # pc_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_pc_2d[:, None], axis=0)) rotated_cc_mask = scipy.ndimage.rotate(cc_mask, -theta_degrees, order=0, reshape=False) # rotate points around center origin_point = np.array([image_size[0] // 2, image_size[1] // 2]) @@ -250,32 +256,28 @@ rot_matrix = np.array([[np.cos(-theta), -np.sin(-theta)], [np.sin(-theta), np.cos(-theta)]]) # Translate points to origin, rotate, then translate back - pc_centered = pc_2d - origin_point - ac_centered = ac_2d - origin_point - - rotated_pc_2d = (rot_matrix @ pc_centered) + origin_point - rotated_ac_2d = (rot_matrix @ ac_centered) + origin_point - - # Add z=0 coordinate to make 3D, then remove it after resampling - contour_3d = np.vstack([contour, np.zeros(contour.shape[1])]) - contour_3d = lapy.tria_mesh.TriaMesh._TriaMesh__resample_polygon(contour_3d.T, 701).T - contour = contour_3d[:2] - - contour = contour[:, :-1] + rotated_pc_2d = rot_matrix @ (pc_2d.astype(float) - origin_point) + origin_point + rotated_ac_2d = rot_matrix @ (ac_2d.astype(float) - origin_point) + origin_point - rotated_ac_2d = np.array(rotated_ac_2d).astype(float) - rotated_pc_2d = np.array(rotated_pc_2d).astype(float) - - # move posterior commisure 5 mm posterior - # FIXME: why is the move 10mm inferior not commented? + # move posterior commissure 5 mm posterior, 10 mm superior # FIXME: multiplication means moving less for smaller voxels, why not division? # changed to division, 5 mm / voxel size => number of voxels to move + # ----> CHECK IF THESE VALUES ARE CONFIRMED GOOD IN TESTING rotated_pc_2d = rotated_pc_2d + np.array([10, -5]) / resolution - # move anterior commisure 1.5 mm anterior - # FIXME: why does the documentation say 1.5mm when the code says 5mm?
+ # move anterior commissure 5 mm anterior rotated_ac_2d = rotated_ac_2d + np.array([0, 5]) / resolution + contour = extract_cc_contour(rotated_cc_mask, contour_smoothing) + + # Add z=0 coordinate to make 3D, then remove it after resampling + #FIXME: change this to using Polygon class when we upgrade lapy + contour_3d = lapy.tria_mesh.TriaMesh._TriaMesh__resample_polygon( + np.append(contour, np.zeros((1, contour.shape[1])), axis=0).T, + 701, + ) + contour = contour_3d.T[:2, :-1] + # find point in contour closest to AC ac_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_ac_2d[:, None], axis=0)) @@ -297,3 +299,71 @@ return contour_rotated, (ac_startpoint_idx, pc_startpoint_idx), (start_point_ac, start_point_pc) else: return contour_rotated, (ac_startpoint_idx, pc_startpoint_idx) + +def find_cc_endpoints( + contour: Points2dType, + ac_2d: Vector2d, + pc_2d: Vector2d, + return_coordinates: bool = False, +): + """Determine the CC endpoints as the contour points closest to the (offset) AC and PC, without rotating the image. + + Parameters + ---------- + contour : np.ndarray of shape (2, N) + Points of the CC contour in AS (millimeter). + ac_2d : np.ndarray of shape (2,) and type float + 2D AS coordinates of the anterior commissure in millimeter. + pc_2d : np.ndarray of shape (2,) and type float + 2D AS coordinates of the posterior commissure in millimeter. + return_coordinates : bool, default=False + If True, return endpoint coordinates. + + Returns + ------- + anterior_posterior_point_indices : pair of ints + Indices of anterior and posterior points in the contour. + anterior_posterior_point_coordinates : pair of Vector2d + Only if return_coordinates is True: Coordinates of anterior and posterior points, each shape (2,). + + Notes + ----- + Expects AS orientation of contour, ac_2d, and pc_2d. + """ + if contour.shape[0] != 2: + raise ValueError(f"contour must have shape (2, N), got {contour.shape}") + if any(p2d.shape != (2,) for p2d in (ac_2d, pc_2d)): + raise ValueError(f"ac_2d and pc_2d must have shape (2,), got {ac_2d.shape} and {pc_2d.shape}") + + # Calculate angle between AC-PC line and horizontal using numpy + ac_pc_vector = pc_2d - ac_2d + horizontal_vector = np.array([0, -20]) + # Calculate angle using dot product formula: cos(theta) = (a·b)/(|a||b|) + dot_product = np.dot(ac_pc_vector, horizontal_vector) + norms = np.linalg.norm(ac_pc_vector) * np.linalg.norm(horizontal_vector) + # The sign of theta is the inverse of the sign of ac_pc_vector[0] + theta = -np.sign(ac_pc_vector[0]) * np.arccos(dot_product / norms) + + rot_matrix_inv = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + # move posterior commissure 5 mm posterior, 10 mm inferior + # FIXME: multiplication means moving less for smaller voxels, why not division?
+ # changed to division, 5 mm / voxel size => number of voxels to move + # ----> CHECK IF THESE VALUES ARE CONFIRMED GOOD IN TESTING + as_offset_pc = np.array([-5, -10], dtype=float) + rotated_pc_2d = pc_2d.astype(float) + rot_matrix_inv @ as_offset_pc + # move anterior commissure 5 mm anterior + as_offset_ac = np.array([5, 0], dtype=float) + rotated_ac_2d = ac_2d.astype(float) + rot_matrix_inv @ as_offset_ac + + # Find the endpoints of the CC shape relative to AC and PC coordinates + # find point in contour closest to AC + ac_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_ac_2d[:, None], axis=0)) + # find point in contour closest to PC + pc_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_pc_2d[:, None], axis=0)) + + if return_coordinates: + start_point_ac, start_point_pc = contour[:, [ac_startpoint_idx, pc_startpoint_idx]].T + + return (ac_startpoint_idx, pc_startpoint_idx), (start_point_ac, start_point_pc) + else: + return ac_startpoint_idx, pc_startpoint_idx diff --git a/CorpusCallosum/shape/mesh.py b/CorpusCallosum/shape/mesh.py index ca4325f0..9d3502f0 100644 --- a/CorpusCallosum/shape/mesh.py +++ b/CorpusCallosum/shape/mesh.py @@ -20,14 +20,15 @@ import nibabel as nib import numpy as np import plotly.graph_objects as go +from lapy import TriaMesh from plotly.io import write_html as plotly_write_html from scipy.ndimage import gaussian_filter1d import FastSurferCNN.utils.logging as logging from CorpusCallosum.shape.contour import CCContour from CorpusCallosum.shape.thickness import make_mesh_from_contour -from FastSurferCNN.utils import nibabelImage -from FastSurferCNN.utils.common import suppress_stdout +from FastSurferCNN.utils import AffineMatrix4x4, nibabelImage +from FastSurferCNN.utils.common import suppress_stdout, update_docstring try: from pyrr import Matrix44 @@ -199,15 +200,14 @@ class CCMesh(lapy.TriaMesh): Triangle indices of the mesh. mesh_vertex_colors : np.ndarray Vertex values for each vertex (CC thickness values) - resolution : float - Spatial resolution of the mesh in millimeters. """ - def __init__(self, - vertices: list | np.ndarray, - faces: list | np.ndarray, - vertex_values: list | np.ndarray | None = None, - resolution: float = 1.0): + def __init__( + self, + vertices: list | np.ndarray, + faces: list | np.ndarray, + vertex_values: list | np.ndarray | None = None, + ): """Initialize a CC_Mesh object. Parameters ---------- @@ -218,12 +218,9 @@ def __init__(self, List of face indices or array of shape (M, 3). vertex_values : list or numpy.ndarray, optional Vertex values for each vertex (CC thickness values) - resolution : float, optional - Spatial resolution of the mesh in millimeters, by default 1.0. """ super().__init__(np.vstack(vertices), np.vstack(faces)) self.mesh_vertex_colors = vertex_values - self.resolution = resolution def plot_mesh( self, @@ -631,16 +628,16 @@ def __make_parent_folder(filename: Path | str) -> None: """ Path(filename).parent.mkdir(parents=False, exist_ok=True) - def to_fs_coordinates( - self, - lr_offset: float, - ) -> "CCMesh": + def to_vox_coordinates( + self: Self, + mesh_ras2vox: AffineMatrix4x4, + ) -> Self: - """Convert mesh coordinates to FreeSurfer coordinate system. + """Convert mesh coordinates from midplane RAS space to voxel coordinates. Parameters ---------- - lr_offset : float - Voxel offset to apply before transformation, this should be often `FSAVERAGE_MIDDLE / vox_size_in_lr`. + mesh_ras2vox : AffineMatrix4x4 + Transformation matrix from midplane mesh space (RAS centered on midplane) to voxel coordinates.
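A note on the `find_cc_endpoints` hunk above: the angle math is dense, so here is a minimal, self-contained sketch of the same pattern. All landmark values and the axis convention are made up for illustration; only the dot-product/arccos angle formula and the rotation of fixed millimeter offsets mirror the patch.

```python
import numpy as np

# Hypothetical AC/PC landmarks in 2D (mm); values and axis convention are illustrative.
ac = np.array([0.0, 0.0])
pc = np.array([-27.0, -3.0])

# Angle between the AC->PC vector and a fixed reference axis,
# via cos(theta) = (a . b) / (|a| |b|); the sign follows the first component.
ac_pc = pc - ac
reference = np.array([0.0, -20.0])
cos_theta = np.dot(ac_pc, reference) / (np.linalg.norm(ac_pc) * np.linalg.norm(reference))
theta = -np.sign(ac_pc[0]) * np.arccos(np.clip(cos_theta, -1.0, 1.0))

# Rotate a fixed mm offset into the tilted AC-PC frame, then shift the landmark;
# the endpoint search starts from this shifted position.
rot_inv = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta),  np.cos(theta)]])
pc_search = pc + rot_inv @ np.array([-5.0, -10.0])

# Closest contour point to the shifted landmark (contour has shape (2, N), as in the patch).
angles = np.linspace(0.0, 2.0 * np.pi, 100, endpoint=False)
contour = 30.0 * np.stack([np.cos(angles), np.sin(angles)])
pc_idx = int(np.argmin(np.linalg.norm(contour - pc_search[:, None], axis=0)))
```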
Returns ------- @@ -650,19 +647,11 @@ def to_fs_coordinates( Notes ----- Mesh coordinates are in ASR (Anterior-Superior-Right) orientation, with the coordinate system origin on - *the* midslice. The function transforms from midslice ASR to LIA vox coordinates. + *the* midslice. The function *first* transforms from midslice ASR to LIA vox coordinates. """ from copy import copy new_object = copy(self) - asrvox_midslice2orig_vox2vox = np.eye(4) - # to LSA - asrvox_midslice2orig_vox2vox[:, [0, 2]] = asrvox_midslice2orig_vox2vox[:, [2, 0]] - # center LR - asrvox_midslice2orig_vox2vox[0, 3] = lr_offset - # flip SI - asrvox_midslice2orig_vox2vox[:, 1] *= -1 - # to LSA # new_object.v = new_object.v[:, [2, 1, 0]] # to voxel @@ -682,24 +671,13 @@ def to_fs_coordinates( # Torig: mri_info --vox2ras-tkr orig.mgz # https://surfer.nmr.mgh.harvard.edu/fswiki/CoordinateSystems - v_vox = np.concatenate([self.v, np.ones((self.v.shape[0], 1))], axis=1) - new_object.v = (v_vox @ asrvox_midslice2orig_vox2vox.T)[:, :3] + new_object.v = (mesh_ras2vox[:3, :3] @ self.v.T).T + mesh_ras2vox[None, :3, 3] # new_object.v = (vox2ras_tkr @ np.concatenate([self.v, np.ones((self.v.shape[0], 1))], axis=1).T).T[:, :3] return new_object - def write_fssurf(self, filename: Path | str, image: str | object | None = None) -> None: - """Save as Freesurfer Surface Geometry file (wrap Nibabel). - - Parameters - ---------- - filename : str - Filename to save to. - image : str, object, None - Path to image or nibabel image object. If specified, the vertices - are assumed to be in voxel coordinates and are converted - to surface RAS (tkr) coordinates before saving. - The expected order of coordinates is (x, y, z) matching - the image voxel indices. + @update_docstring(parent_doc=TriaMesh.write_fssurf.__doc__) + def write_fssurf(self, filename: Path | str, image: str | nibabelImage | None = None) -> None: + """{parent_doc} Notes ----- @@ -725,7 +703,7 @@ def write_morph_data(self, filename: Path | str) -> None: @classmethod def from_contours( - cls: Self, + cls: type[Self], contours: list[CCContour], lr_center: float = 0, closed: bool = False, @@ -764,38 +742,42 @@ def from_contours( - Creates caps at both ends. - Applies smoothing. - Colors caps based on thickness values. - """ - # Check that all contours have the same resolution - resolution = contours[0].resolution - for idx, contour in enumerate(contours[1:], start=1): - if not np.isclose(contour.resolution, resolution): - raise ValueError( - f"All contours must have the same resolution. " - f"Expected {resolution}, but contour at index {idx} has {contour.resolution}." - ) + z_coordinates = np.array([contour.z_position for contour in contours]) + same_z_position = np.isclose(z_coordinates[:, None], z_coordinates[None, :]) + # filter for diagonal and duplicates + unique_same_z_position = np.logical_and(same_z_position, np.tri(z_coordinates.shape[0], k=-1, dtype=bool).T) + if np.any(unique_same_z_position): + raise ValueError( + f"All contours must have different z_positions, but {np.array(np.where(unique_same_z_position)).T} " + f"have similar z_positions." 
+ ) # Calculate z coordinates for each slice - z_coordinates = (np.arange(len(contours)) - len(contours) // 2) * contours[0].resolution + lr_center + # z_coordinates = (np.arange(len(contours)) - len(contours) // 2) * contours[0].resolution + lr_center # Build vertices list with z-coordinates vertices = [] faces = [] vertex_start_indices = [] # Track starting index for each contour current_index = 0 + previous_contour: CCContour | None = None - for i, contour in enumerate(contours): + for contour in contours: vertex_start_indices.append(current_index) - vertices.append(np.hstack([contour.contour, np.full((len(contour.contour), 1), z_coordinates[i])])) + vertices.append(np.hstack([np.full((len(contour.points), 1), contour.z_position), contour.points])) # Check if there's a next valid contour to connect to - if i + 1 < len(contours): - contour2 = contours[i + 1] - faces_between = make_triangles_between_contours(contour.contour, contour2.contour) + if previous_contour is not None: + if len(previous_contour.points) != len(contour.points): + raise ValueError("All contours must have the same number of points!") + faces_between = make_triangles_between_contours(previous_contour.points, contour.points) faces.append(faces_between + current_index) - current_index += len(contour.contour) + current_index += len(contour.points) + + previous_contour = contour vertex_values = np.concatenate([contour.thickness_values for contour in contours]) @@ -807,6 +789,9 @@ def from_contours( vertex_values = tmp_mesh.mesh_vertex_colors if closed: + # FIXME: this functionality is untested and not used + logger.warning("CCMesh.from_contours(closed=True) is untested and likely has errors.") + # Close the mesh by creating caps on both ends # Left cap (first slice) - use counterclockwise orientation left_side_points, left_side_trias = make_mesh_from_contour(vertices[: vertex_start_indices[1]][..., :2]) @@ -816,7 +801,7 @@ def from_contours( right_side_points, right_side_trias = make_mesh_from_contour(vertices[vertex_start_indices[-1]:][..., :2]) right_side_points = np.hstack([right_side_points, np.full((len(right_side_points), 1), z_coordinates[-1])]) - #FIXME: Can we remove this if-statement? + # color_sides is a legacy visualization option to allow caps to have thickness colors color_sides = True if color_sides: left_side_points, left_side_trias, left_side_colors = _create_cap( @@ -836,9 +821,9 @@ def from_contours( right_side_trias = right_side_trias + current_index current_index += len(right_side_points) - # FIXME: should this not be a concatenate statements? + # should this not be a single concatenate statement?
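(The list assembly the comment above questions continues right after this aside.) `make_triangles_between_contours` stitches two same-length contour rings into a band of triangles. The sketch below shows the idea with a hypothetical stand-in helper, assuming matched point ordering between rings; it is not the patch's implementation.

```python
import numpy as np

def triangles_between_rings(n: int) -> np.ndarray:
    """Triangle indices connecting two rings of n points each; ring 0
    occupies vertex indices [0, n), ring 1 occupies [n, 2n)."""
    i = np.arange(n)
    j = (i + 1) % n  # next point on the ring, with wrap-around
    upper = i + n
    # two triangles per quad: (i, j, i+n) and (j, j+n, i+n)
    return np.concatenate([np.stack([i, j, upper], axis=1),
                           np.stack([j, j + n, upper], axis=1)], axis=0)

# Two rings at z positions 0 and 1, mirroring the vertex assembly above
# (z_position first, then the 2D contour points).
angles = np.linspace(0.0, 2.0 * np.pi, 8, endpoint=False)
ring = np.stack([np.cos(angles), np.sin(angles)], axis=1)
vertices = np.vstack([np.column_stack([np.full(8, z), ring]) for z in (0.0, 1.0)])
faces = triangles_between_rings(8)  # shape (16, 3)
```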
vertices = [vertices, left_side_points, right_side_points] faces = [faces, left_side_trias, right_side_trias] vertex_values = [vertex_values, left_side_colors, right_side_colors] - return cls(vertices, faces, vertex_values=vertex_values, resolution=resolution) + return cls(vertices, faces, vertex_values=vertex_values) diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index 440c6787..c4ad81da 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -18,11 +18,13 @@ from typing import get_args import numpy as np +from nibabel.freesurfer.mghformat import MGHHeader +from numpy import typing as npt import FastSurferCNN.utils.logging as logging -from CorpusCallosum.data.constants import CC_LABEL, FSAVERAGE_MIDDLE, SUBSEGMENT_LABELS +from CorpusCallosum.data.constants import CC_LABEL, SUBSEGMENT_LABELS from CorpusCallosum.shape.contour import CCContour -from CorpusCallosum.shape.endpoint_heuristic import get_endpoints +from CorpusCallosum.shape.endpoint_heuristic import connect_diagonally_connected_components from CorpusCallosum.shape.mesh import CCMesh from CorpusCallosum.shape.metrics import calculate_cc_index from CorpusCallosum.shape.subsegment_contour import ( @@ -33,11 +35,17 @@ subsegment_midline_orthogonal, transform_to_acpc_standard, ) -from CorpusCallosum.shape.thickness import cc_thickness, convert_to_ras -from CorpusCallosum.utils.types import CCMeasuresDict, ContourThickness, Points2dType, SliceSelection, SubdivisionMethod +from CorpusCallosum.shape.thickness import cc_thickness +from CorpusCallosum.utils.types import ( + CCMeasuresDict, + ContourThickness, + Points2dType, + Polygon2dType, + SliceSelection, + SubdivisionMethod, +) from CorpusCallosum.utils.visualization import plot_contours -from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Mask2d, ScalarType, Shape2d, Shape3d, Vector2d, nibabelImage, \ - nibabelHeader +from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Mask2d, Shape2d, Shape3d, Vector2d from FastSurferCNN.utils.common import SubjectDirectory, update_docstring from FastSurferCNN.utils.parallel import process_executor, thread_executor @@ -50,37 +58,43 @@ LIA_ORIENTATION[2,1] = -1 -def create_sag_slice_vox2vox(slice_idx: int, fsaverage_middle: float) -> AffineMatrix4x4: - """Create slice-specific slice to full affine transformation matrix. - - Returns a volume to slice in volume affine. +def offset_affine(offset: npt.ArrayLike) -> AffineMatrix4x4: + """Generate an affine transformation matrix that only constitutes an offset (vector). Parameters ---------- - slice_idx : int - Index of the slice to transform. - fsaverage_middle : float - Reference middle slice index in fsaverage space. + offset : array_like + A 3-dimensional offset vector (shape (3,)) to offset with. Returns ------- np.ndarray - Modified 4x4 affine transformation matrix for the specific slice. + Modified 4x4 affine transformation matrix with the specific offset. 
+ + Raises + ------ + TypeError + If offset cannot be converted to an ndarray of shape (3,). """ - slice2full_vox2vox: AffineMatrix4x4 = np.eye(4, dtype=float) - slice2full_vox2vox[0, 3] = -fsaverage_middle + slice_idx - return slice2full_vox2vox + _offset = np.asarray(offset) + if not isinstance(_offset, np.ndarray) or _offset.shape != (3,): + raise TypeError("offset must convert to an ndarray of shape (3,)!") + vox2vox: AffineMatrix4x4 = np.eye(4, dtype=float) + vox2vox[0:3, 3] = _offset + return vox2vox @update_docstring(SubdivisionMethod=str(get_args(SubdivisionMethod))[1:-1]) def recon_cc_surf_measures_multi( segmentation: np.ndarray[Shape3d, np.dtype[np.int_]], - upright_affine_header: tuple[AffineMatrix4x4, nibabelHeader], slice_selection: SliceSelection, + upright_header: MGHHeader, + fsavg2midslab_vox2vox: AffineMatrix4x4, fsavg_vox2ras: AffineMatrix4x4, + orig2fsavg_vox2vox: AffineMatrix4x4, midslices: Image3d, - ac_coords: Vector2d, - pc_coords: Vector2d, + ac_coords_vox: Vector2d, + pc_coords_vox: Vector2d, num_thickness_points: int, subdivisions: list[float], subdivision_method: SubdivisionMethod, @@ -93,19 +107,23 @@ Parameters ---------- segmentation : np.ndarray - 3D segmentation array. - upright_affine_header : tuple[AffineMatrix4x4, nibabelHeader] - A tuple of the vox2ras matrix and the header of the upright image. + 3D segmentation array in LIA orientation. slice_selection : str Which slices to process ('middle', 'all', or slice number). + upright_header : MGHHeader + The header of the upright image. + fsavg2midslab_vox2vox : AffineMatrix4x4 + The vox2vox transformation matrix from fsaverage (upright) space to the segmentation slab. fsavg_vox2ras : np.ndarray Base affine transformation matrix (fsaverage, upright space). + orig2fsavg_vox2vox : AffineMatrix4x4 + The transformation matrix from orig to fsaverage in voxel space. midslices : np.ndarray Array of mid-sagittal slices. - ac_coords : np.ndarray - Anterior commissure coordinates. - pc_coords : np.ndarray - Posterior commissure coordinates. + ac_coords_vox : np.ndarray + AC voxel coordinates with shape (2,) containing its [y,x] positions. + pc_coords_vox : np.ndarray + PC voxel coordinates with shape (2,) containing its [y,x] positions. num_thickness_points : int Number of points for thickness estimation. subdivisions : list[float] @@ -138,8 +156,8 @@ _each_slice = partial( recon_cc_surf_measure, segmentation, - ac_coords=ac_coords, - pc_coords=pc_coords, + ac_coords_vox=ac_coords_vox, + pc_coords_vox=pc_coords_vox, num_thickness_points=num_thickness_points, subdivisions=subdivisions, subdivision_method=subdivision_method, @@ -161,15 +179,21 @@ num_slices = 1 slices_to_recon = [int(slice_selection)] - _gen_fsavg2slice_vox2vox = partial(create_sag_slice_vox2vox, fsaverage_middle=FSAVERAGE_MIDDLE) - per_slice_vox2ras = fsavg_vox2ras @ np.stack(list(map(_gen_fsavg2slice_vox2vox, slices_to_recon)), axis=0) + def _gen_slice2slab_vox2vox(_slice_idx: int) -> AffineMatrix4x4: + # The slice_idx offset must be negative, because we are going from left to right.
+ return offset_affine([_slice_idx, 0, 0]) + + fsavg_midslab_vox2ras = fsavg_vox2ras @ np.linalg.inv(fsavg2midslab_vox2vox) + per_slice_vox2ras = fsavg_midslab_vox2ras @ np.stack(list(map(_gen_slice2slab_vox2vox, slices_to_recon)), axis=0) per_slice_recon = process_executor().map(_each_slice, slices_to_recon, per_slice_vox2ras, chunksize=1) cc_contours = [] run = thread_executor().submit wants_output = subject_dir.has_attribute - for i, (slice_idx, _results) in enumerate(zip(slices_to_recon, per_slice_recon, strict=True)): + output_path = subject_dir.filename_by_attribute + slice_iterator = zip(slices_to_recon, per_slice_vox2ras, per_slice_recon, strict=True) + for i, (slice_idx, this_slice_vox2ras, _results) in enumerate(slice_iterator): progress = f" ({i+1} of {num_slices})" if num_slices > 1 else "" logger.info(f"Calculating CC measurements for slice {slice_idx+1}{progress}") # unpack values from _results @@ -179,7 +203,8 @@ def recon_cc_surf_measures_multi( contour_in_as_space: Points2dType = contour_in_as_space_and_thickness[:, :2] thickness_values: np.ndarray[tuple[int], np.dtype[np.float_]] = contour_in_as_space_and_thickness[:, 2] - cc_contours.append(CCContour(contour_in_as_space, thickness_values, endpoint_idxs, resolution=vox_size[0])) + z_value = this_slice_vox2ras[0, 3] + cc_contours.append(CCContour(contour_in_as_space, thickness_values, endpoint_idxs, z_position=z_value)) if cc_measures is None: # this should not happen, but just in case logger.warning(f"Slice index {slice_idx+1}{progress} returned result `None`") @@ -188,7 +213,7 @@ def recon_cc_surf_measures_multi( is_debug = logger.getEffectiveLevel() <= logging.DEBUG is_midslice = slice_idx == num_slices // 2 if wants_output("cc_qc_image") and (is_debug or is_midslice): - qc_imgs: list[Path] = [subject_dir.filename_by_attribute("cc_qc_image")] + qc_imgs: list[Path] = [output_path("cc_qc_image")] if is_debug: qc_slice_img = qc_imgs[0].with_suffix(f".slice_{slice_idx}.png") qc_imgs = (qc_imgs if is_midslice else []) + [qc_slice_img] @@ -204,16 +229,15 @@ def recon_cc_surf_measures_multi( midline_equidistant=cc_measures["midline_equidistant"], levelpaths=cc_measures["levelpaths"], output_path=qc_imgs, - ac_coords=ac_coords, - pc_coords=pc_coords, + ac_coords=ac_coords_vox, + pc_coords=pc_coords_vox, vox_size=vox_size, title=f"CC Subsegmentation by {subdivision_method} (Slice {slice_idx + 1})", ) ) - if wants_output("save_template_dir"): - template_dir = subject_dir.filename_by_attribute("save_template_dir") + template_dir = output_path("save_template_dir") # ensure directory exists template_dir.mkdir(parents=True, exist_ok=True) logger.info("Saving template files (contours.txt, thickness_values.txt, " @@ -230,41 +254,49 @@ def recon_cc_surf_measures_multi( _cc_contours = thread_executor().map(_resample_thickness, cc_contours) cc_mesh = CCMesh.from_contours(list(_cc_contours), smooth=1) if wants_output("cc_html"): - logger.info(f"Saving CC 3D visualization to {subject_dir.filename_by_attribute('cc_html')}") - io_futures.append(run( - cc_mesh.plot_mesh,output_path=subject_dir.filename_by_attribute("cc_html")), - ) + logger.info(f"Saving CC 3D visualization to {output_path('cc_html')}") + io_futures.append(run(cc_mesh.plot_mesh, output_path=output_path("cc_html"))) if wants_output("cc_mesh"): - vtk_file_path = subject_dir.filename_by_attribute("cc_mesh") + vtk_file_path = output_path("cc_mesh") logger.info(f"Saving vtk file to {vtk_file_path}") io_futures.append(run(cc_mesh.write_vtk, vtk_file_path)) - if 
wants_output("cc_thickness_overlay"): - overlay_file_path = subject_dir.filename_by_attribute("cc_thickness_overlay") + if wants_output("cc_thickness_overlay") and not wants_output("cc_thickness_image"): + overlay_file_path = output_path("cc_thickness_overlay") logger.info(f"Saving overlay file to {overlay_file_path}") io_futures.append(run(cc_mesh.write_morph_data, overlay_file_path)) if any(wants_output(f"cc_{n}") for n in ("thickness_image", "cc_surf")): import nibabel as nib - upright_img = nib.MGHImage(np.zeros(() * 3, dtype=np.uint8)) - - # the mesh is generated in upright coordinates, so we need to also transform to orig coordinates - cc_mesh = cc_mesh.to_fs_coordinates(lr_offset=FSAVERAGE_MIDDLE / vox_size[0]) - if wants_output("cc_surf"): - surf_file_path = subject_dir.filename_by_attribute("cc_surf") - logger.info(f"Saving surf file to {surf_file_path}") - cc_mesh.write_fssurf(str(surf_file_path), image=upright_img) - # io_futures.append(run(cc_mesh.write_fssurf, str(surf_file_path), image=orig)) - - if wants_output("cc_thickness_image"): - thickness_image_path = subject_dir.filename_by_attribute("cc_thickness_image") - logger.info(f"Saving thickness image to {thickness_image_path}") - cc_mesh.snap_cc_picture(thickness_image_path, ref_image=upright_img) - - if not slice_cc_measures: - logger.error("Error: No valid slices were found for postprocessing") - raise ValueError("No valid slices were found for postprocessing") + up_data: Image3d[np.uint8] = np.empty(upright_header["dims"][:3], dtype=upright_header.get_data_dtype()) + upright_img = nib.MGHImage(up_data, fsavg_vox2ras, upright_header) + # the mesh is generated in upright coordinates, so we need to also transform to orig coordinates + # FIXME: this is currently not in RAS coordinates! 
+ + # Mesh is fsavg_midplane (RAS); we need to transform to voxel coordinates + # fsavg ras is also on the midslice, so this is fine and we multiply in the IA and SP offsets + cc_mesh = cc_mesh.to_vox_coordinates(mesh_ras2vox=np.linalg.inv(fsavg_vox2ras @ orig2fsavg_vox2vox)) + #FIXME: to_fs_coordinate needs to transform from upright to + if wants_output("cc_thickness_image"): + # this will also write overlay and surface + thickness_image_path = output_path("cc_thickness_image") + logger.info(f"Saving thickness image to {thickness_image_path}") + kwargs = { + "fssurf_file": output_path("cc_surf") if wants_output("cc_surf") else None, + "overlay_file": output_path("cc_thickness_overlay") + if wants_output("cc_thickness_overlay") else None, + "ref_image": upright_img, + } + cc_mesh.snap_cc_picture(thickness_image_path, **kwargs) + elif wants_output("cc_surf"): + surf_file_path = output_path("cc_surf") + logger.info(f"Saving surf file to {surf_file_path}") + io_futures.append(run(cc_mesh.write_fssurf, str(surf_file_path), image=upright_img)) + + if not slice_cc_measures: + logger.error("Error: No valid slices were found for postprocessing") + raise ValueError("No valid slices were found for postprocessing") return slice_cc_measures, io_futures @@ -279,9 +311,9 @@ def _resample_thickness(contour: CCContour) -> CCContour: def recon_cc_surf_measure( segmentation: np.ndarray[Shape2d, np.dtype[np.int_]], slice_idx: int, - affine: AffineMatrix4x4, - ac_coords: Vector2d, - pc_coords: Vector2d, + slice_lia_vox2midslice_ras: AffineMatrix4x4, + ac_coords_vox: Vector2d, + pc_coords_vox: Vector2d, num_thickness_points: int, subdivisions: list[float], subdivision_method: SubdivisionMethod, @@ -296,12 +328,12 @@ def recon_cc_surf_measure( 3D segmentation array. slice_idx : int Index of the slice to process. - affine : AffineMatrix4x4 + slice_lia_vox2midslice_ras : AffineMatrix4x4 4x4 affine transformation matrix. - ac_coords : np.ndarray of shape (2,) and type float - Anterior commissure coordinates. - pc_coords : np.ndarray of shape (2,) and type float - Posterior commissure coordinates. + ac_coords_vox : np.ndarray + AC voxel coordinates with shape (2,) containing its [y,x] positions. + pc_coords_vox : np.ndarray + PC voxel coordinates with shape (2,) containing its [y,x] positions. num_thickness_points : int Number of points for thickness estimation. subdivisions : list[float] @@ -318,7 +350,7 @@ def recon_cc_surf_measure( measures : CCMeasuresDict Dictionary containing measurements if successful. contour_with_thickness : np.ndarray - Contour points with thickness information, shape (3, N) for [x, y, thickness]. + Contour points with thickness information in fsavg_midslice_ras space, shape (3, N) for [x, y, thickness]. endpoint_indices : pair of ints Indices of the anterior and posterior endpoints on the contour. 
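The hunk above composes `fsavg_vox2ras @ orig2fsavg_vox2vox` and inverts it so that `to_vox_coordinates` can carry mesh RAS coordinates into voxel indices with a single 4x4 matrix. A self-contained sketch of the same arithmetic; the matrices here are made up, real ones come from the image header and the registration:

```python
import numpy as np

def apply_affine_to_vertices(affine: np.ndarray, v: np.ndarray) -> np.ndarray:
    # Same arithmetic as (M[:3, :3] @ v.T).T + M[None, :3, 3] in the patch.
    return v @ affine[:3, :3].T + affine[:3, 3]

# Illustrative placeholder matrices only.
fsavg_vox2ras = np.eye(4)
fsavg_vox2ras[:3, 3] = [-128.0, -128.0, -128.0]
orig2fsavg_vox2vox = np.eye(4)

mesh_ras2vox = np.linalg.inv(fsavg_vox2ras @ orig2fsavg_vox2vox)
vertices = np.array([[0.0, 10.0, -5.0],
                     [1.0, 12.0, -4.0]])
vox_coords = apply_affine_to_vertices(mesh_ras2vox, vertices)  # (N, 3) voxel coordinates
```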
@@ -339,31 +371,58 @@ def recon_cc_surf_measure( cc_mask_slice: Mask2d = np.equal(segmentation[slice_idx], CC_LABEL) if not np.any(cc_mask_slice): raise ValueError(f"No CC found in slice {slice_idx}") - contour, endpoint_idxs = get_endpoints( - cc_mask_slice, - ac_coords, - pc_coords, - (vox_size[1], vox_size[2]), - return_coordinates=False, - contour_smoothing=contour_smoothing, + # clean up cc mask + cc_mask = connect_diagonally_connected_components(cc_mask_slice) + # create a CCContour from the cc_mask and transform to RAS coordinates + # - R coordinate is stored in _contour.z_position + # - AS coordinates are stored in _contour.points + _contour = CCContour.from_mask_and_acpc( + cc_mask, ac_coords_vox, pc_coords_vox, + slice_vox2ras=slice_lia_vox2midslice_ras, contour_smoothing=contour_smoothing, ) - contour_ras = convert_to_ras(contour, affine) - endpoint_idxs: tuple[int, int] + contour_as = _contour.points.T + endpoint_idxs = _contour.endpoint_idxs + # FIXME: could probably also use _contour.create_levelpaths here, but that does not currently return all values + # levelpaths, thickness = _contour.create_levelpaths(num_thickness_points) + + # FIXME: If we create CCContour objects here already (as we can), we should probably return that instead of the + # contour_with_thickness value (as the CCContour has all that information as well) + + # # find_contour_and_endpoints extracts the contour and finds ac and pc endpoints for shape analysis + # # contour is in IA voxel coordinates + # contour, endpoint_idxs = find_contour_and_endpoints( + # cc_mask_slice, + # ac_coords_vox, + # pc_coords_vox, + # (vox_size[1], vox_size[2]), + # return_coordinates=False, + # contour_smoothing=contour_smoothing, + # ) + # # contour_ras uses coordinates in the fsavg_midslice_ras coordinate system, now re-order/flip slice_ia + # # coordinates to fsavg_ras coordinates. + # #FIXME: double-check the sign of the z_offset (lr) here, currently starts positive for first slice + # offsets = np.asarray([-vox_size[0] * (slice_idx - segmentation.shape[0] // 2), 0, 0, 1]) + # affine = np.concatenate([slice_lia_vox2midslice_ras[:, :3], offsets[:, None]], axis=1) + # # convert to fsavg_ras coordinates (which are mid-slice-based) + # contour_as = (slice_lia_vox2midslice_ras @ np.append(contour, 1, axis=0))[1:3] + contour_with_thickness: ContourThickness - midline_len, thickness, curvature, midline_equi, levelpaths, contour_with_thickness, endpoint_idxs = cc_thickness( - contour_ras[1:].T, - endpoint_idxs, - n_points=num_thickness_points, - ) + # cc_thickness wants contour to be in midslice_ras coordinates, i.e. millimeter distances on the respective slice. 
+ midline_len, thickness, curvature, midline_equi, levelpaths, contour_with_thickness, endpoint_idxs = \ + cc_thickness( + contour_as.T, + endpoint_idxs, + n_points=num_thickness_points, + ) # thickness values in contour_with_thickness is not equally sampled, different shape # to compute length of paths: diff between consecutive points (N-1, 2) => norm (N-1,) => sum (1,) thickness_profile = np.stack([np.sum(np.linalg.norm(np.diff(x[:, :2], axis=0), axis=1)) for x in levelpaths]) - acpc_contour_coords_ras = contour_ras[:, list(endpoint_idxs)].T + acpc_contour_coords_as = contour_as[:, list(endpoint_idxs)].T contour_in_acpc_space, ac_pt_acpc, pc_pt_acpc, rotate_back_acpc = transform_to_acpc_standard( - contour_ras[1:], - *acpc_contour_coords_ras[:, 1:], + contour_as, + *acpc_contour_coords_as, ) cc_index = calculate_cc_index(contour_in_acpc_space) @@ -371,8 +430,8 @@ split_contours: ContourList if subdivision_method == "shape": _subdivisions = np.asarray(subdivisions) - areas, split_contours = subsegment_midline_orthogonal(midline_equi, _subdivisions, contour_ras[1:], plot=False) - split_contours = [transform_to_acpc_standard(split_contour, *acpc_contour_coords_ras[:, 1:])[0] + areas, split_contours = subsegment_midline_orthogonal(midline_equi, _subdivisions, contour_as, plot=False) + split_contours = [transform_to_acpc_standard(split_contour, *acpc_contour_coords_as)[0] for split_contour in split_contours] elif subdivision_method == "vertical": areas, split_contours = subdivide_contour(contour_in_acpc_space, subdivisions, plot=False) @@ -394,7 +453,7 @@ raise ValueError(f"Invalid subdivision method {subdivision_method}") total_area = np.sum(areas) - total_perimeter = np.sum(np.sqrt(np.sum((np.diff(contour_ras[:, 1:], axis=0))**2, axis=1))) + total_perimeter = np.sum(np.sqrt(np.sum((np.diff(contour_as, axis=0))**2, axis=1))) circularity = 4 * np.pi * total_area / (total_perimeter**2) # Transform split contours back to original space @@ -418,20 +477,17 @@ return measures, contour_with_thickness, endpoint_idxs -def vectorized_line_test( - coords_x: np.ndarray[tuple[int], np.dtype[ScalarType]], - coords_y: np.ndarray[tuple[int], np.dtype[ScalarType]], +def test_right_of_line( + coords: Polygon2dType, line_start: Vector2d, line_end: Vector2d, ) -> np.ndarray[tuple[int], np.dtype[np.bool_]]: - """Vectorized version of point_relative_to_line for arrays of points. + """Test whether points in coords are to the right of the line (line_start->line_end). Parameters ---------- - coords_x : np.ndarray - Array of x coordinates. - coords_y : np.ndarray - Array of y coordinates. + coords : np.ndarray - Array of coordinates of shape (2, N). line_start : array-like [x, y] coordinates of line start point. line_end : array-like @@ -442,16 +498,14 @@ np.ndarray - Boolean array where True means point is to the left of the line. + Boolean array where True means the point is to the right of the line (line_start->line_end).
""" - # FIXME: rename this function to something more indicative # Vector from line_start to line_end line_vec = np.array(line_end) - np.array(line_start) # Vectors from line_start to all points (vectorized) - point_vec_x = coords_x - line_start[0] - point_vec_y = coords_y - line_start[1] - + point_vec = coords - np.expand_dims(line_start, axis=list(range(1, coords.ndim))) + # Cross product (vectorized): positive means point is to the left of the line - cross_products = line_vec[0] * point_vec_y - line_vec[1] * point_vec_x + cross_products = line_vec[0] * point_vec[1] - line_vec[1] * point_vec[0] return cross_products > 0 @@ -546,7 +600,7 @@ def make_subdivision_mask( - Updates labels for those points. """ - # unique contour points are the points where sub-division lines were inserted + # unique_contour_points are the points where sub-division lines were inserted unique_contour_points: list[Points2dType] = get_unique_contour_points(split_contours) # shape (N, 2) subdivision_segments = unique_contour_points[1:] @@ -556,7 +610,7 @@ def make_subdivision_mask( # Create coordinate grids for all points in the slice rows, cols = slice_shape - y_coords, x_coords = np.mgrid[0:rows, 0:cols] + coords = np.array(np.mgrid[0:rows, 0:cols])[[1, 0]] cc_subsegment_lut_anterior_to_posterior = SUBSEGMENT_LABELS.copy() cc_subsegment_lut_anterior_to_posterior.reverse() @@ -564,15 +618,17 @@ def make_subdivision_mask( # Initialize with first segment label subdivision_mask = np.full(slice_shape, cc_subsegment_lut_anterior_to_posterior[0], dtype=np.int32) - # Process each subdivision line + # Process each subdivision line, subdivision_segments has for each division line the two points that are on the + # contour and divide the subsegments for segment_idx, segment_points in enumerate(subdivision_segments): - # FIXME: names for line_start and line_end? + # line_start and line_end are the intersection points of the CC subsegmentation boundary and the contour line + # --> find all voxels posterior to the line in question line_start: Vector2d = segment_points[0] / vox_size line_end: Vector2d = segment_points[-1] / vox_size - # Vectorized test: find all points to the right of this line - # FIXME: line defined by what? Is this inside the polygon or the line from line_start to line_end? - points_right_of_line = vectorized_line_test(x_coords, y_coords, line_start, line_end) + # Vectorized test: find all points to the right of line (line_start->line_end) + # right_of_line == posterior to line + points_right_of_line = test_right_of_line(coords, line_start, line_end) # All points to the right of this line belong to the next segment or beyond subdivision_mask[points_right_of_line] = cc_subsegment_lut_anterior_to_posterior[segment_idx + 1] diff --git a/CorpusCallosum/shape/subsegment_contour.py b/CorpusCallosum/shape/subsegment_contour.py index b730cb51..ea345744 100644 --- a/CorpusCallosum/shape/subsegment_contour.py +++ b/CorpusCallosum/shape/subsegment_contour.py @@ -13,17 +13,15 @@ # limitations under the License. 
from collections.abc import Callable -from typing import TYPE_CHECKING, Literal +from typing import Literal import matplotlib.pyplot as plt import numpy as np from scipy.spatial import ConvexHull from CorpusCallosum.utils.types import ContourList, Points2dType, Polygon2dType, Polygon3dType -from FastSurferCNN.utils import Mask2d, Mask3d, ScalarType, Vector2d, nibabelImage +from FastSurferCNN.utils import ScalarType, Vector2d -if TYPE_CHECKING: - import pandas as pd def minimum_bounding_rectangle(points: Points2dType) -> np.ndarray[tuple[Literal[4], Literal[2]], np.dtype[ScalarType]]: """Find the smallest bounding rectangle for a set of points. @@ -144,6 +142,9 @@ def subsegment_midline_orthogonal( # get points after midline length of splits # get vertex closest to midline end + + # FIXME: should this not always be the posterior endpoint index? Can we not standardize this even earlier, and then + # pull this into CCContour.from_mask_and_acpc? midline_end_idx = np.argmin(np.linalg.norm(contour.T - midline[-1], axis=1)) # roll contour start to midline end contour = np.roll(contour, -midline_end_idx, axis=1) @@ -850,48 +851,6 @@ def rotate_back(x: Polygon2dType) -> Polygon2dType: return contour_acpc, np.array([0, 0], dtype=float), np.array([-ac_pc_dist, 0], dtype=float), rotate_back -def preprocess_cc(cc_label_nib: nibabelImage, paths_csv: "pd.DataFrame", subj_id: str) \ - -> tuple[Mask2d, Vector2d, Vector2d]: - """Preprocess corpus callosum mask and extract AC/PC coordinates. - - Parameters - ---------- - cc_label_nib : nibabel.Nifti1Image - NIfTI image containing corpus callosum segmentation. - paths_csv : pd.DataFrame - DataFrame containing AC and PC coordinates. - subj_id : str - Subject ID to look up in paths_csv. - - Returns - ------- - cc_mask : np.ndarray - Binary mask of corpus callosum. - AC_2d : np.ndarray - 2D coordinates of anterior commissure. - PC_2d : np.ndarray - 2D coordinates of posterior commissure. - - """ - #FIXME: this function is not used anywhere - _cc_mask: Mask3d = np.asarray(cc_label_nib.dataobj) == 192 - cc_mask: Mask2d = _cc_mask[_cc_mask.shape[0] // 2] - - posterior_commisure_center = paths_csv.loc[subj_id, "PC_center_r": "PC_center_s"].to_numpy().astype(float) - anterior_commisure_center = paths_csv.loc[subj_id, "AC_center_r": "AC_center_s"].to_numpy().astype(float) - - # adjust LR from label coordinates to orig_up coordinates - posterior_commisure_center[0] = 128 - anterior_commisure_center[0] = 128 - - # orientation I, A - # rotate image so anterior and posterior commisure are horizontal - ac_2d = anterior_commisure_center[1:] - pc_2d = posterior_commisure_center[1:] - - return cc_mask, ac_2d, pc_2d - - def get_primary_eigenvector(contour_ras: Polygon2dType) -> tuple[Vector2d, Vector2d]: """Calculate primary eigenvector of contour points using PCA. diff --git a/CorpusCallosum/shape/thickness.py b/CorpusCallosum/shape/thickness.py index 0303d1fe..7c1c46d3 100644 --- a/CorpusCallosum/shape/thickness.py +++ b/CorpusCallosum/shape/thickness.py @@ -46,94 +46,6 @@ def compute_curvature(path: Points2dType) -> np.ndarray[tuple[int], np.dtype[np. return angle_diffs -@overload -def convert_to_ras(contour: np.ndarray, vox2ras_matrix: np.ndarray, get_parameters: Literal[False] = False) \ - -> np.ndarray: ... -@overload -def convert_to_ras(contour: np.ndarray, vox2ras_matrix: np.ndarray, get_parameters: Literal[True]) \ - -> tuple[np.ndarray, bool, bool, bool]: ...
- - -def convert_to_ras( - contour: np.ndarray, - vox2ras_matrix: np.ndarray, - return_parameters: bool = False -): - """Convert contour coordinates from voxel space to RAS space. - - Parameters - ---------- - contour : np.ndarray - Array of shape (2, N) or (3, N) containing contour coordinates. - vox2ras_matrix : np.ndarray - 4x4 voxel to RAS transformation matrix. - return_parameters : bool, default=False - If True, return additional transformation parameters (see below). - - Returns - ------- - contour : np.ndarray - Transformed contour coordinates of shape (3, N). - anterior_reversed : bool - Only if return_parameters is True, whether anterior axis was reversed. - superior_reversed : bool - Only if return_parameters is True, whether superior axis was reversed. - swap_axes : bool - Only if return_parameters is True, whether axes were swapped. - """ - # converting to AS (no left-right dimension), out of plane movement is ignored, - # so we only do scaling, axes swapping and flipping - no rotation - # translation is ignored - if contour.shape[0] == 2: - # get only axis swaps from the rotation part of the vox2ras matrix - axis_swaps = np.round(vox2ras_matrix[:3, :3], 0) - permutation = np.argwhere(axis_swaps != 0)[:, 1] - assert len(permutation) == 3 - - idx_superior = np.argwhere(permutation == 2) - idx_anterior = np.argwhere(permutation == 1) - - # swap axes if indicated from vox2ras - if swap_axes := idx_anterior > idx_superior: - # swap anterior and superior - contour = contour[[1, 0]] - - # determine if axis were reversed - superior_reversed = np.any(axis_swaps[2, :] == -1) - anterior_reversed = np.any(axis_swaps[1, :] == -1) - - # flip axes if necessary - if superior_reversed: - contour[1] = -contour[1] - if anterior_reversed: - contour[0] = -contour[0] - - # get scaling by getting length of three column vectors - scaling = np.linalg.norm(vox2ras_matrix[:3, :3], axis=0) - - # voxel * vox_size = mm - contour = (contour.T * scaling[1:]).T - - # append a 0-R coordinate - contour = np.concatenate([np.zeros((1, contour.shape[1])), contour], axis=0) - - if return_parameters: - return contour, anterior_reversed, superior_reversed, swap_axes - else: - return contour - - # Add a third dimension (z) with 0 and a fourth dimension (homogeneous coordinate) with 1 - elif contour.shape[0] == 3: - contour_homogeneous = np.vstack([contour, np.ones(contour.shape[1])]) - - # Apply the transformation - contour = (vox2ras_matrix @ contour_homogeneous)[:3, :] - return contour - else: - raise ValueError("Invalid shape of contour") - - def set_contour_zero_idx(contour, idx, anterior_endpoint_idx, posterior_endpoint_idx): """Roll contour points to set a new zero index, while keeping track of CC endpoints. @@ -274,7 +186,7 @@ def make_mesh_from_contour( max_volume: float = 0.5, min_angle: float = 25, verbose: bool = False -) -> tuple[np.ndarray, np.ndarray]: +) -> tuple[Points2dType[np.float_], np.ndarray[tuple[int, Literal[3]], np.dtype[np.int_]]]: """Create a triangular mesh from a 2D contour. Parameters @@ -290,9 +202,10 @@ def make_mesh_from_contour( Returns ------- - tuple[np.ndarray, np.ndarray] - - mesh_points : Array of shape (M, 2) containing mesh vertices. - - mesh_trias : Array of shape (K, 3) containing triangle indices. + mesh_points : np.ndarray + Array of shape (M, 2) containing mesh vertices. + mesh_trias : np.ndarray + Array of shape (K, 3) containing triangle indices. 
Notes ----- @@ -304,13 +217,13 @@ def make_mesh_from_contour( # use meshpy to create mesh info = triangle.MeshInfo() - info.set_points(contour_2d) + info.set_points(contour_2d) # needs to be (N, D) info.set_facets(facets) # NOTE: crashes if contour has duplicate points !! mesh = triangle.build(info, max_volume=max_volume, min_angle=min_angle, verbose=verbose) - mesh_points = np.array(mesh.points) - mesh_trias = np.array(mesh.elements) + mesh_points: Points2dType[np.float_] = np.array(mesh.points, dtype=float) + mesh_trias: np.ndarray[tuple[int, Literal[3]], np.dtype[np.int_]] = np.array(mesh.elements, dtype=int) return mesh_points, mesh_trias diff --git a/CorpusCallosum/utils/mapping_helpers.py b/CorpusCallosum/utils/mapping_helpers.py index 8dd72e5d..eacdbd79 100644 --- a/CorpusCallosum/utils/mapping_helpers.py +++ b/CorpusCallosum/utils/mapping_helpers.py @@ -4,7 +4,6 @@ import nibabel as nib import numpy as np import SimpleITK as sitk -from numpy import typing as npt from scipy.ndimage import affine_transform from CorpusCallosum.data.constants import CC_LABEL, FORNIX_LABEL @@ -125,6 +124,8 @@ def apply_transform_to_pt(pts: Vector3d | Polygon3dType, T: AffineMatrix4x4, inv np.ndarray Transformed point coordinates, shape (3,) or (3, N). """ + # FIXME: This function is very similar to nibabel.affines.apply_affine, reduce duplication. + # Differences: Here, pts dimensions are (3,) or (3, N), in apply_affine, they are (..., D-1) for DxD affines. if inv: T = np.linalg.inv(T) @@ -135,7 +136,7 @@ def apply_transform_to_pt(pts: Vector3d | Polygon3dType, T: AffineMatrix4x4, inv def calc_mapping_to_standard_space( - orig: "nib.Nifti1Image", + orig: nibabelImage, ac_coords_3d: Vector3d, pc_coords_3d: Vector3d, orig_fsaverage_vox2vox: AffineMatrix4x4, @@ -144,7 +145,7 @@ def calc_mapping_to_standard_space( Parameters ---------- - orig : nib.Nifti1Image + orig : nibabelImage Original image. ac_coords_3d : np.ndarray AC coordinates in 3D space. @@ -214,27 +215,27 @@ def calc_mapping_to_standard_space( def apply_transform_to_volume( orig_image: nibabelImage, - vox2vox: AffineMatrix4x4, - affine: AffineMatrix4x4, + interp_vox2vox: AffineMatrix4x4, + save_vox2ras: AffineMatrix4x4 | None = None, header: nib.freesurfer.mghformat.MGHHeader | None = None, output_path: str | Path | None = None, output_size: np.ndarray | None = None, order: int = 1 -) -> npt.NDArray[float]: +) -> Image3d[np.float_]: """Apply transformation to a volume and save the result. Parameters ---------- orig_image : nibabelImage Input volume. - vox2vox : np.ndarray + interp_vox2vox : np.ndarray Transformation matrix to apply to the data, this is from input-to-output space. - affine : AffineMatrix4x4, optional + save_vox2ras : AffineMatrix4x4, optional The vox2ras matrix of the output image, only relevant if output_path is given. header : nibabelHeader, optional Header for the output image, only relevant if output_path is given, if None will default to orig_image header. output_path : str or Path, optional - If output_path is provided, saves the result under this path. + If output_path is provided, saves the result under this path using the dtype of header (or orig_image). output_size : np.ndarray, optional Size of output volume, uses input size by default `None`. order : int, default=1 @@ -242,8 +243,8 @@ def apply_transform_to_volume( Returns ------- - npt.NDArray[float] - Transformed volume data. + np.ndarray + Transformed volume data of shape `output_size` and type float. 
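Before the `apply_transform_to_volume` body continues below: its comments rely on `scipy.ndimage.affine_transform` mapping output coordinates to input coordinates (`input_coord = matrix @ output_coord + offset`), which is why the code passes the inverse of the input-to-output vox2vox. A minimal check of that documented convention, with a made-up shift:

```python
import numpy as np
from scipy.ndimage import affine_transform

vol = np.zeros((8, 8, 8))
vol[2, 3, 4] = 1.0

vox2vox = np.eye(4)   # input -> output: shift by +1 voxel along the first axis
vox2vox[0, 3] = 1.0

# Pass the INVERSE as a homogeneous (ndim+1, ndim+1) matrix, as the patch does.
out = affine_transform(vol, np.linalg.inv(vox2vox), output_shape=vol.shape, order=0)
assert out[3, 3, 4] == 1.0   # the bright voxel moved from index 2 to index 3
```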
Notes ----- @@ -254,13 +255,21 @@ output_size = np.array(orig_image.shape) if header is None: header = orig_image.header + if save_vox2ras is None: + save_vox2ras = orig_image.affine @ interp_vox2vox # transform / resample the volume with vox2vox, note this needs to be the inverse of input2output vox2vox! # affine_transform definition is: input_coord = matrix @ output_coord + offset ( == MATRIX_HOM @ output_coord_hom) # --> output_coord = inv(matrix) @ (input_coord - offset) ( == inv(MATRIX_HOM) @ input_coord_hom) - resampled = affine_transform(orig_image.get_fdata(), np.linalg.inv(vox2vox), output_shape=output_size, order=order) + resampled = affine_transform( + orig_image.get_fdata(), + np.linalg.inv(interp_vox2vox), + output_shape=output_size, + order=order, + ) if output_path is not None: logger.info(f"Saving transformed volume to {output_path}") - nib.save(nib.MGHImage(resampled.astype(orig_image.get_data_dtype()), affine, header), output_path) + resampled_typecast = resampled.astype((header if header else orig_image).get_data_dtype()) + nib.save(nib.MGHImage(resampled_typecast, save_vox2ras, header), output_path) return resampled @@ -287,10 +296,7 @@ def make_affine(simpleITKImage: sitk.Image) -> AffineMatrix4x4: # get affine transform in LPS c = [simpleITKImage.TransformContinuousIndexToPhysicalPoint(p) for p in np.eye(4)[:, :3]] c = np.array(c) - affine = np.concatenate( - [np.concatenate([c[0:3] - c[3:], c[3:]], axis=0), [[0.0], [0.0], [0.0], [1.0]]], - axis=1, - ) + affine = np.append(np.append(c[0:3] - c[3:], c[3:], axis=0), np.eye(4)[:, 3:], axis=1) # np.eye(4)[:, 3:] is the homogeneous column [0, 0, 0, 1] affine = np.transpose(affine) # convert to RAS to match nibabel affine = np.matmul(np.diag([-1.0, -1.0, 1.0, 1.0]), affine) From f3dcdcaa83f015659021b60df16756032608713b Mon Sep 17 00:00:00 2001 From: ClePol Date: Mon, 22 Dec 2025 15:41:12 +0100 Subject: [PATCH 53/68] fixed label consolidation --- CorpusCallosum/paint_cc_into_pred.py | 387 ++++++++++++++++++--------- 1 file changed, 258 insertions(+), 129 deletions(-) diff --git a/CorpusCallosum/paint_cc_into_pred.py b/CorpusCallosum/paint_cc_into_pred.py index 420abaea..39d0e61a 100644 --- a/CorpusCallosum/paint_cc_into_pred.py +++ b/CorpusCallosum/paint_cc_into_pred.py @@ -132,132 +132,262 @@ def paint_in_cc(pred: npt.NDArray[np.int_], The CC labels (251-255) from aseg_cc are copied into pred. """ cc_mask = mask_in_array(aseg_cc, SUBSEGMENT_LABELS) + + # Count what's being replaced + replaced_labels = pred[cc_mask] + num_wm_replaced = np.sum((replaced_labels == 2) | (replaced_labels == 41)) + num_other_replaced = np.sum((replaced_labels != 0) & (replaced_labels != 2) & (replaced_labels != 41)) + num_background_replaced = np.sum(replaced_labels == 0) + + logger.info(f"Painting CC: {np.sum(cc_mask)} voxels (replacing {num_wm_replaced} WM, " + f"{num_background_replaced} background, {num_other_replaced} other)") + pred[cc_mask] = aseg_cc[cc_mask] return pred +def _fill_gaps_in_direction( + corrected_pred: npt.NDArray[np.int_], + potential_fill: npt.NDArray[np.bool_], + source_binary: npt.NDArray[np.bool_], + target_binary: npt.NDArray[np.bool_], + x_slice: int, + direction: str, + max_gap_voxels: int, + fillable_labels: set[int] +) -> int: + """Fill gaps between source and target masks in a specific direction. + + Parameters + ---------- + corrected_pred : npt.NDArray[np.int_] + The segmentation array to modify in place. + potential_fill : npt.NDArray[np.bool_] + 2D mask of potential fill regions for this slice.
+ source_binary : npt.NDArray[np.bool_] + 2D binary mask of source structure (e.g., CC). + target_binary : npt.NDArray[np.bool_] + 2D binary mask of target structure (e.g., ventricle). + x_slice : int + The x-coordinate of the current slice. + direction : str + Either 'inferior-superior' (iterate over z) or 'anterior-posterior' (iterate over y). + max_gap_voxels : int + Maximum gap size in voxels for this direction. + fillable_labels : set[int] + Set of label values that can be replaced (e.g., {0, 2, 41} for background and WM). + + Returns + ------- + int + Number of voxels filled. + """ + voxels_filled = 0 + + if direction == 'inferior-superior': + # Iterate over z dimension + for z in range(potential_fill.shape[1]): + potential_fill_line = potential_fill[:, z] + labeled_gaps, num_gaps = ndimage.label(potential_fill_line) + source_line = source_binary[:, z] + target_line = target_binary[:, z] + + for gap_label in range(1, num_gaps + 1): + gap_mask = labeled_gaps == gap_label + + # Check that both source and target are connected to the gap + dilated_gap_mask = ndimage.binary_dilation(gap_mask, iterations=1) + if not np.any(source_line & dilated_gap_mask): + continue + if not np.any(target_line & dilated_gap_mask): + continue + + # Get the target label from adjacent target voxels + target_label_location = np.where(target_line & dilated_gap_mask)[0] + if len(target_label_location) == 0: + continue + target_label = corrected_pred[x_slice, target_label_location[0], z] + + # Check gap size + if np.sum(gap_mask) > max_gap_voxels: + continue + + # Fill voxels that have fillable labels + current_labels = corrected_pred[x_slice, :, z] + fill_mask = gap_mask & np.isin(current_labels, list(fillable_labels)) + voxels_filled += np.sum(fill_mask) + corrected_pred[x_slice, :, z][fill_mask] = target_label + + elif direction == 'anterior-posterior': + # Iterate over y dimension + for y in range(potential_fill.shape[0]): + potential_fill_line = potential_fill[y, :] + labeled_gaps, num_gaps = ndimage.label(potential_fill_line) + source_line = source_binary[y, :] + target_line = target_binary[y, :] + + for gap_label in range(1, num_gaps + 1): + gap_mask = labeled_gaps == gap_label + + # Check that both source and target are connected to the gap + dilated_gap_mask = ndimage.binary_dilation(gap_mask, iterations=1) + if not np.any(source_line & dilated_gap_mask): + continue + if not np.any(target_line & dilated_gap_mask): + continue + + # Get the target label from adjacent target voxels + target_label_location = np.where(target_line & dilated_gap_mask)[0] + if len(target_label_location) == 0: + continue + target_label = corrected_pred[x_slice, y, target_label_location[0]] + + # Check gap size + if np.sum(gap_mask) > max_gap_voxels: + continue + + # Fill voxels that have fillable labels + current_labels = corrected_pred[x_slice, y, :] + fill_mask = gap_mask & np.isin(current_labels, list(fillable_labels)) + voxels_filled += np.sum(fill_mask) + corrected_pred[x_slice, y, :][fill_mask] = target_label + + return voxels_filled + + +def _fill_gaps_between_structures( + corrected_pred: npt.NDArray[np.int_], + source_mask: npt.NDArray[np.bool_], + target_mask: npt.NDArray[np.bool_], + voxel_size: tuple[float, float, float], + close_gap_size_mm: float, + fillable_labels: set[int], + description: str +) -> int: + """Fill small gaps between two structures. + + Parameters + ---------- + corrected_pred : npt.NDArray[np.int_] + The segmentation array to modify in place. 
+ source_mask : npt.NDArray[np.bool_] + 3D binary mask of source structure (e.g., CC). + target_mask : npt.NDArray[np.bool_] + 3D binary mask of target structure (e.g., ventricle or background). + voxel_size : tuple[float, float, float] + Voxel size in mm. + close_gap_size_mm : float + Maximum gap size in mm. + fillable_labels : set[int] + Set of label values that can be replaced. + description : str + Description for logging. + + Returns + ------- + int + Number of voxels filled. + """ + # Convert mm gap size to voxels + max_gap_vox_anterior_posterior = int(np.ceil(close_gap_size_mm / voxel_size[1])) + max_gap_vox_inferior_superior = int(np.ceil(close_gap_size_mm / voxel_size[2])) + max_gap_vox_max = max(max_gap_vox_anterior_posterior, max_gap_vox_inferior_superior) + + voxels_filled = 0 + + # Process each slice independently + for x in range(corrected_pred.shape[0]): + source_slice = source_mask[x] + target_slice = target_mask[x] + + # Skip slices without both structures + if not (source_slice.any() and target_slice.any()): + continue + + # Create binary masks for this slice + source_binary = source_slice.astype(bool) + target_binary = target_slice.astype(bool) + + # Dilate both masks to find potential connection points + source_dilated = ndimage.binary_dilation(source_binary, iterations=max_gap_vox_max) + target_dilated = ndimage.binary_dilation(target_binary, iterations=max_gap_vox_max) + + # Find voxels that are adjacent to both structures but not part of either + potential_fill = (source_dilated & target_dilated) & ~(source_binary | target_binary) + + # Fill gaps in inferior-superior direction + voxels_filled += _fill_gaps_in_direction( + corrected_pred, potential_fill, source_binary, target_binary, + x, 'inferior-superior', max_gap_vox_inferior_superior, fillable_labels + ) + + # Fill gaps in anterior-posterior direction + voxels_filled += _fill_gaps_in_direction( + corrected_pred, potential_fill, source_binary, target_binary, + x, 'anterior-posterior', max_gap_vox_anterior_posterior, fillable_labels + ) + + if voxels_filled > 0: + logger.info(f"Filled {voxels_filled} voxels {description}") + + return voxels_filled + + def correct_wm_ventricles( aseg_cc: npt.NDArray[np.int_], fornix_mask: npt.NDArray[np.bool_], voxel_size: tuple[float, float, float], close_gap_size_mm: float = 3.0 ) -> npt.NDArray[np.int_]: - """Correct WM mask and ventricle labels according to the CC and fornix masks. + """Fill small gaps between corpus callosum, ventricles, and background. - The function - Take non-CC-connected WM components -> remove - Take FN -> WM - Fill space in superior inferior direction between CC and left/right Ventricle with corresponding Ventricle labels - """ + This function performs two gap-filling operations: + 1. Fills WM and background gaps between CC and ventricles with ventricle labels + 2. Fills WM gaps between CC and background with background label + + Note: Fornix and non-CC-connected WM component removal are intentionally not implemented + in this function as they have been removed from the processing pipeline. + + Parameters + ---------- + aseg_cc : npt.NDArray[np.int_] + Aseg segmentation with CC already painted in. + fornix_mask : npt.NDArray[np.bool_] + Mask of the fornix. Not currently used (kept for interface compatibility). + voxel_size : tuple[float, float, float] + Voxel size of the aseg image in mm. + close_gap_size_mm : float, default=3.0 + Maximum size of the gap to fill in millimeters. 
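`_fill_gaps_in_direction` above scans 1D lines for runs of gap voxels that touch both structures after one dilation step. The 1D core of that test, simplified into a self-contained sketch (the labels and the gap limit are made up; the real code additionally restricts filling to `fillable_labels`):

```python
import numpy as np
from scipy import ndimage

line = np.array([1, 1, 0, 0, 2, 2])      # source label 1, a 2-voxel gap, target label 2
source, target = line == 1, line == 2
gap = ~(source | target)

labels, num = ndimage.label(gap)          # connected runs of gap voxels
max_gap_voxels = 3
for k in range(1, num + 1):
    run = labels == k
    grown = ndimage.binary_dilation(run)  # does the run touch both neighbors?
    if (grown & source).any() and (grown & target).any() and run.sum() <= max_gap_voxels:
        line[run] = 2                     # adopt the target's label, as in the patch

assert line.tolist() == [1, 1, 2, 2, 2, 2]
```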
+ Returns + ------- + npt.NDArray[np.int_] + Corrected segmentation map with filled gaps. + """ # Create a copy to avoid modifying the original corrected_pred = aseg_cc.copy() - + # Get CC mask (labels 251-255) cc_mask = mask_in_array(aseg_cc, SUBSEGMENT_LABELS) - # Get left and right ventricle masks - all_ventricle_mask = (aseg_cc == 4) | (aseg_cc == 43) - - # Combine all WM labels - all_wm_mask = (aseg_cc == 2) | (aseg_cc == 41) - - # 1. Fill space between CC and ventricles - # Only fill small gaps (up to 3 voxels) between CC and ventricle boundaries - #for ventricle_label, ventricle_mask in [(4, left_ventricle_mask), (43, right_ventricle_mask)]: + # Get ventricle masks (left=4, right=43) + ventricle_mask = (aseg_cc == 4) | (aseg_cc == 43) - # Process each slice independently - for x in range(corrected_pred.shape[0]): - cc_slice = cc_mask[x] - #vent_slice = ventricle_mask - all_wm_slice = all_wm_mask[x] - - if all_wm_slice.any() and cc_slice.any(): - - # Dilate CC mask to find adjacent voxels, then check for overlap with component - cc_dilated = ndimage.binary_dilation(cc_slice, iterations=1) - # Label connected components in WM - labeled_wm, num_components = ndimage.label(all_wm_slice) - - # Find components that are adjacent to CC and remove them - for label in range(1, num_components + 1): - component_mask = labeled_wm == label - # Check if this component is adjacent to (touches) the CC - if np.any(component_mask & cc_dilated): - corrected_pred[x][component_mask] = 0 # Set to background - - if fornix_mask[x].any(): - fornix_slice = fornix_mask[x] - # count WM labels overlapping with fornix - left_wm_overlap = np.sum(fornix_slice & (aseg_cc == 2)) - right_wm_overlap = np.sum(fornix_slice & (aseg_cc == 41)) - corrected_pred[x][fornix_slice] = 2 + (left_wm_overlap > right_wm_overlap) * 39 # Left WM / Right WM - - vent_slice = all_ventricle_mask - potential_fill = np.asarray([False]) - if cc_slice.any() and vent_slice.any(): - # Create binary masks for this slice - cc_binary = cc_slice.astype(bool) - vent_binary = vent_slice.astype(bool) - - # Dilate both masks slightly to find potential connection points - max_gap_vox = int(np.ceil(voxel_size[1] * close_gap_size_mm)) - cc_dilated = ndimage.binary_dilation(cc_binary, iterations=max_gap_vox) - vent_dilated = ndimage.binary_dilation(vent_binary, iterations=max_gap_vox) - - # Find voxels that are adjacent to both CC and ventricle but not part of either - potential_fill = (cc_dilated & vent_dilated) & ~(cc_binary | vent_binary) - - # Only fill small gaps between CC and ventricle in inferior-superior direction - if not potential_fill.any(): - for z in range(potential_fill.shape[1]): - potential_fill_line = potential_fill[:, z] - labeled_gaps, num_gaps = ndimage.label(potential_fill_line) - cc_line = cc_binary[:, z] - vent_line = vent_binary[:, z] - - for gap_label in range(1, num_gaps + 1): - gap_mask = labeled_gaps == gap_label - - # check that CC and ventricle are connected to the gap_mask - dilated_gap_mask = ndimage.binary_dilation(gap_mask, iterations=1) - if not np.any(cc_line & dilated_gap_mask): - continue - if not np.any(vent_line & dilated_gap_mask): - continue - - vent_label_location = np.where(vent_line & dilated_gap_mask)[0] - vent_label = corrected_pred[x, vent_label_location, z] - - if np.sum(gap_mask) > max_gap_vox: - continue - - corrected_pred[x, :, z][gap_mask & (corrected_pred[x, :, z] == 0)] = vent_label - - # Process gaps in z-direction (within each y-row) - for y in range(potential_fill.shape[0]): - potential_fill_line = 
potential_fill[y, :] - labeled_gaps, num_gaps = ndimage.label(potential_fill_line) - cc_line = cc_binary[y, :] - vent_line = vent_binary[y, :] - - for gap_label in range(1, num_gaps + 1): - gap_mask = labeled_gaps == gap_label - - # check that CC and ventricle are connected to the gap_mask - dilated_gap_mask = ndimage.binary_dilation(gap_mask, iterations=1) - if not np.any(cc_line & dilated_gap_mask): - continue - if not np.any(vent_line & dilated_gap_mask): - continue - - vent_label_location = np.where(vent_line & dilated_gap_mask)[0] - if len(vent_label_location) > 0: - vent_label = corrected_pred[x, y, vent_label_location[0]] # Take first match - - if np.sum(gap_mask) > max_gap_vox: - continue - - corrected_pred[x, y, :][gap_mask & (corrected_pred[x, y, :] == 0)] = vent_label + # Get background mask + background_mask = aseg_cc == 0 + + # 1. Fill gaps between CC and ventricles (replace WM and background with ventricle labels) + _fill_gaps_between_structures( + corrected_pred, cc_mask, ventricle_mask, voxel_size, close_gap_size_mm, + fillable_labels={0, 2, 41}, # background and WM + description="between CC and ventricles (WM/background → ventricle)" + ) + + # 2. Fill WM gaps between CC and background (replace WM with background) + _fill_gaps_between_structures( + corrected_pred, cc_mask, background_mask, voxel_size, close_gap_size_mm, + fillable_labels={2, 41}, # only WM + description="between CC and background (WM → background)" + ) return corrected_pred @@ -297,11 +427,16 @@ def _is_conform(img, dtype, verbose): if not np.allclose(cc_seg_image.affine, aseg_image.affine): sys.exit("Error: The affine matrices of the aseg and the corpus callosum images are not the same.") - # Paint CC into prediction - pred_with_cc = paint_in_cc(aseg_data, cc_seg_data) + # Count initial labels before any modifications + initial_cc = np.sum(mask_in_array(aseg_data, SUBSEGMENT_LABELS)) + initial_fornix = np.sum(aseg_data == FORNIX_LABEL) + initial_wm = np.sum((aseg_data == 2) | (aseg_data == 41)) + initial_ventricles = np.sum((aseg_data == 4) | (aseg_data == 43)) + + # Paint CC into prediction (modifies aseg_data in place) + paint_in_cc(aseg_data, cc_seg_data) - # Apply WM and ventricle corrections - logger.info("Applying white matter and ventricle corrections...") + # Apply ventricle gap filling corrections fornix_mask = cc_seg_data == FORNIX_LABEL voxel_size = tuple(aseg_image.header.get_zooms()) pred_corrected = correct_wm_ventricles(aseg_data, fornix_mask, voxel_size) @@ -321,28 +456,22 @@ def _is_conform(img, dtype, verbose): else: rta_fut = None - # Count initial labels - initial_cc = np.sum(mask_in_array(aseg_data, SUBSEGMENT_LABELS)) - initial_fornix = np.sum(aseg_data == FORNIX_LABEL) - initial_wm = np.sum((aseg_data == 2) | (aseg_data == 41)) - logger.info(f"Initial segmentation: CC={initial_cc}, Fornix={initial_fornix}, WM={initial_wm}") - - after_paint_cc = np.sum(mask_in_array(pred_with_cc, SUBSEGMENT_LABELS)) - logger.info(f"After painting CC: {after_paint_cc} CC voxels added") - # Count final labels final_cc = np.sum(mask_in_array(pred_corrected, SUBSEGMENT_LABELS)) final_fornix = np.sum(pred_corrected == FORNIX_LABEL) final_wm = np.sum((pred_corrected == 2) | (pred_corrected == 41)) final_ventricles = np.sum((pred_corrected == 4) | (pred_corrected == 43)) - logger.info(f"Final segmentation: CC={final_cc}, Fornix={final_fornix},\ - WM={final_wm}, Ventricles={final_ventricles}") - logger.info(f"Changes: CC +{final_cc-initial_cc}, Fornix {final_fornix-initial_fornix},\ - WM 
{final_wm-initial_wm}") + wm_change = final_wm - initial_wm + vent_change = final_ventricles - initial_ventricles + cc_change = final_cc - initial_cc + + logger.info(f"Changes: Corpus Callosum {'+' if cc_change >= 0 else ''}{cc_change}, " + f"White Matter {'+' if wm_change >= 0 else ''}{wm_change}, " + f"Ventricles {'+' if vent_change >= 0 else ''}{vent_change}") + # Wait for all IO operations to complete + io_fut.result() if rta_fut is not None: - _ = rta_fut.result() - - sys.exit(0) + rta_fut.result() From 8757466ab33557d7ede885fa90eca2006dea420b Mon Sep 17 00:00:00 2001 From: ClePol Date: Mon, 22 Dec 2025 16:54:37 +0100 Subject: [PATCH 54/68] visualization & commandline interface bugfixes --- CorpusCallosum/cc_visualization.py | 25 +++++++++++++++++++------ CorpusCallosum/fastsurfer_cc.py | 1 + CorpusCallosum/shape/contour.py | 28 +++++++++++++++------------- 3 files changed, 35 insertions(+), 19 deletions(-) diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index d05305c2..bc4baeb6 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -119,13 +119,30 @@ def load_contours_from_template_dir( fsaverage_contour = None contours: list[CCContour] = [] + # First pass: collect all indices to determine the range + indices = [] for thickness_file in thickness_files: + try: + idx = int(thickness_file.stem.split("_")[-1]) + indices.append(idx) + except ValueError: + # skip files that do not follow the expected naming + continue + + # Calculate z_positions centered around the middle slice + num_slices = len(indices) + middle_idx = num_slices // 2 + + for i, thickness_file in enumerate(thickness_files): try: idx = int(thickness_file.stem.split("_")[-1]) except ValueError: # skip files that do not follow the expected naming continue + # Calculate z_position: use the index offset from middle, scaled by resolution + z_position = (idx - indices[middle_idx]) * resolution + contour_file = template_dir / f"contour_{idx}.txt" if not contour_file.exists(): @@ -138,15 +155,11 @@ def load_contours_from_template_dir( # create measurement points (points = 2 x levelpaths) according to number of thickness values fsaverage_contour.create_levelpaths(num_points=num_thickness_values // 2, update_data=True) current_contour = fsaverage_contour.copy() + current_contour.z_position = z_position current_contour.load_thickness_values(thickness_file) else: - # this is kinda ugly - maybe we need to overload the constructor to load the contour and thickness values? - # FIXME: The z_position in from_contour is still incorrect, currently all Contours would be "registered" for - # the midslice. - current_contour = CCContour.from_contour_file(contour_file, thickness_file, z_position=0.0) - # current_contour.load_contour(contour_file) - # current_contour.load_thickness_values(thickness_file) + current_contour = CCContour.from_contour_file(contour_file, thickness_file, z_position=z_position) current_contour.fill_thickness_values() contours.append(current_contour) diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index d8cbdb80..c293ba0d 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -148,6 +148,7 @@ def _set_help_sid(action): parser.add_argument( "--subdivisions", type=float, + nargs='*', metavar="FRAC", default=_FixFloatFormattingList([1 / 6, 1 / 2, 2 / 3, 3 / 4], ".3f"), help="List of subdivision fractions for the corpus callosum subsegmentation." 
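The nargs='*' addition above is the substance of the command-line fix: with type=float alone, argparse consumes exactly one token, so passing several fractions to --subdivisions failed with "unrecognized arguments". A standalone sketch of the corrected behavior (illustrative parser only; the real default is the pipeline's _FixFloatFormattingList, replaced here by a plain list):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--subdivisions",
        type=float,
        nargs="*",  # accept zero or more whitespace-separated floats
        default=[1 / 6, 1 / 2, 2 / 3, 3 / 4],  # Hofer-Frahm-style fractions
    )

    # before the fix, "--subdivisions 0.2 0.5 0.75" raised
    # "error: unrecognized arguments: 0.5 0.75"; now it parses to a list:
    args = parser.parse_args(["--subdivisions", "0.2", "0.5", "0.75"])
    assert args.subdivisions == [0.2, 0.5, 0.75]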
diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py index 97ca7be6..b1f14b2a 100644 --- a/CorpusCallosum/shape/contour.py +++ b/CorpusCallosum/shape/contour.py @@ -714,20 +714,22 @@ def _load_thickness_values( else: raise ValueError("Thickness values file must contain a single column") - if len(values) != len(contour): - if original_thickness_vertices is None: - new_values = values - elif np.sum(~np.isnan(values)) == len(original_thickness_vertices): - new_values = np.full(len(contour), np.nan) - new_values[original_thickness_vertices] = values[~np.isnan(values)] - else: - raise ValueError( - f"Number of thickness values {len(values)} does not match number of points in the contour " - f"{len(contour)} and current number of measurement points {len(original_thickness_vertices)} does " - f"not match the number of set thickness values {np.sum(~np.isnan(values))}." - ) + if len(values) == len(contour): + # Perfect match - use values directly + new_values = values + elif original_thickness_vertices is None: + # No original vertices specified, use values as-is (may differ in length) + new_values = values + elif np.sum(~np.isnan(values)) == len(original_thickness_vertices): + # Values match the number of measurement points, map them to the contour + new_values = np.full(len(contour), np.nan) + new_values[original_thickness_vertices] = values[~np.isnan(values)] else: - raise ValueError(f"Number of thickness values in {input_path} does not match the vertices of the path!") + raise ValueError( + f"Number of thickness values {len(values)} does not match number of points in the contour " + f"{len(contour)} and current number of measurement points {len(original_thickness_vertices)} does " + f"not match the number of set thickness values {np.sum(~np.isnan(values))}." 
+ ) return new_values From aa427e733008ed3c80b5372b185ba6d5d6e1ec67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Mon, 22 Dec 2025 17:32:55 +0100 Subject: [PATCH 55/68] - Fix saving the cc_surf (only saved when thickness_image was requested - Resolved outstanding issues where vox_size was used, but correct application of vox2ras should have been used (subdivision mask, cc_image) - Fix cc_image - Fix subdivision mask - vectorize get_unique_contour_points - Remove resolved FIXME comments - Remove unused variables --- CorpusCallosum/cc_visualization.py | 2 +- CorpusCallosum/fastsurfer_cc.py | 38 ++++---- CorpusCallosum/shape/postprocessing.py | 118 +++++++++---------------- CorpusCallosum/utils/visualization.py | 88 +++++++++--------- 4 files changed, 110 insertions(+), 136 deletions(-) diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index bc4baeb6..f17af60f 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -133,7 +133,7 @@ def load_contours_from_template_dir( num_slices = len(indices) middle_idx = num_slices // 2 - for i, thickness_file in enumerate(thickness_files): + for thickness_file in thickness_files: try: idx = int(thickness_file.stem.split("_")[-1]) except ValueError: diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index c293ba0d..55473043 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -18,7 +18,7 @@ from collections.abc import Iterable from pathlib import Path from time import perf_counter_ns -from typing import Literal, TypeVar, cast +from typing import Literal, TypeVar, cast, get_args import nibabel as nib import numpy as np @@ -83,6 +83,18 @@ _TPathLike = TypeVar("_TPathLike", str, Path, Literal[None]) +CCMeasures = Literal[ + "areas", + "thickness", + "curvature", + "midline_length", + "circularity", + "cc_index", + "total_area", + "total_perimeter", + "thickness_profile", +] + class ArgumentDefaultsHelpFormatter(HelpFormatter): """Help message formatter which adds default values to argument help.""" @@ -755,7 +767,6 @@ def main( # start saving upright volume, this is the image in fsaverage space but not yet oriented via AC-PC if sd.has_attribute("upright_volume"): # upright == fsaverage-aligned - # FIXME: upright currently does not get saved correctly io_futures.append( thread_executor().submit( apply_transform_to_volume, @@ -841,7 +852,6 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: subdivisions=subdivisions, subdivision_method=cast(SubdivisionMethod, subdivision_method), contour_smoothing=contour_smoothing, - vox_size=vox_size, subject_dir=sd, ) io_futures.extend(slice_io_futures) @@ -859,7 +869,7 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: cc_subseg_midslice = make_subdivision_mask( (cc_fn_seg_labels.shape[1], cc_fn_seg_labels.shape[2]), middle_slice_result["split_contours"], - vox_size[1:3], + vox2ras=fsavg_vox2ras @ np.linalg.inv(fsavg2midslice_vox2vox), ) else: logger.warning("Too many subsegments for lookup table, skipping sub-division of output segmentation.") @@ -886,24 +896,14 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: orig2midslice_vox2vox=orig2midslice_vox2vox, )) - METRICS = [ - "areas", - "thickness", - "curvature", - "midline_length", - "circularity", - "cc_index", - "total_area", - "total_perimeter", - "thickness_profile", - ] + metrics: tuple[CCMeasures] = get_args(CCMeasures) # Record key 
metrics for middle slice - output_metrics_middle_slice = {metric: middle_slice_result[metric] for metric in METRICS} + output_metrics_middle_slice = {metric: middle_slice_result[metric] for metric in metrics} # Create enhanced output dictionary with all slice results per_slice_output_dict = { - "slices": [convert_numpy_to_json_serializable({metric: result[metric] for metric in METRICS}) + "slices": [convert_numpy_to_json_serializable({metric: result[metric] for metric in metrics}) for result in slice_results], } @@ -960,14 +960,14 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: save_cc_measures_json, sd.filename_by_attribute('cc_mid_measures'), output_metrics_middle_slice | additional_metrics, - )) + )) if sd.has_attribute("cc_measures"): io_futures.append(thread_executor().submit( save_cc_measures_json, sd.filename_by_attribute("cc_measures"), per_slice_output_dict | additional_metrics, - )) + )) # save lta to fsaverage space diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index c4ad81da..14bc3999 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -36,14 +36,7 @@ transform_to_acpc_standard, ) from CorpusCallosum.shape.thickness import cc_thickness -from CorpusCallosum.utils.types import ( - CCMeasuresDict, - ContourThickness, - Points2dType, - Polygon2dType, - SliceSelection, - SubdivisionMethod, -) +from CorpusCallosum.utils.types import CCMeasuresDict, ContourThickness, Points2dType, SliceSelection, SubdivisionMethod from CorpusCallosum.utils.visualization import plot_contours from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Mask2d, Shape2d, Shape3d, Vector2d from FastSurferCNN.utils.common import SubjectDirectory, update_docstring @@ -100,7 +93,6 @@ def recon_cc_surf_measures_multi( subdivision_method: SubdivisionMethod, contour_smoothing: int, subject_dir: SubjectDirectory, - vox_size: tuple[float, float, float], ) -> tuple[list[CCMeasuresDict], list[concurrent.futures.Future]]: """Surface reconstruction and metrics computation of corpus callosum slices based on selection mode. @@ -134,8 +126,6 @@ def recon_cc_surf_measures_multi( Gaussian sigma for contour smoothing. subject_dir : SubjectDirectory The SubjectDirectory object managing file names in the subject directory. - vox_size : 3-tuple of floats - LIA-oriented voxel size in millimeters (x, y, z). Returns ------- @@ -162,7 +152,6 @@ def recon_cc_surf_measures_multi( subdivisions=subdivisions, subdivision_method=subdivision_method, contour_smoothing=contour_smoothing, - vox_size=vox_size, ) # Process multiple slices or specific slice @@ -224,14 +213,16 @@ def _gen_slice2slab_vox2vox(_slice_idx: int) -> AffineMatrix4x4: io_futures.append( run( plot_contours, - transformed=midslices[current_slice_in_volume:current_slice_in_volume+1], + # select the data of the current slice + slice_or_slab=midslices[[current_slice_in_volume]], + # the following need to be in voxel coordinates... 
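+                    # plot_contours receives this slice's vox2ras below and converts these
+                    # upright-RAS inputs back to voxel coordinates itself (apply_affine with
+                    # the inverted vox2ras), so no manual vox_size scaling happens here anymore.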
split_contours=cc_measures["split_contours"], midline_equidistant=cc_measures["midline_equidistant"], levelpaths=cc_measures["levelpaths"], output_path=qc_imgs, - ac_coords=ac_coords_vox, - pc_coords=pc_coords_vox, - vox_size=vox_size, + ac_coords_vox=ac_coords_vox, + pc_coords_vox=pc_coords_vox, + vox2ras=this_slice_vox2ras, title=f"CC Subsegmentation by {subdivision_method} (Slice {slice_idx + 1})", ) ) @@ -244,8 +235,6 @@ def _gen_slice2slab_vox2vox(_slice_idx: int) -> AffineMatrix4x4: f"thickness_measurement_points.txt) to {template_dir}") run = run for j in range(len(cc_contours)): - # FIXME: check, if this is fixed (thickness values not nan == 200) - # this does not seem to be thread-safe, do not parallelize! io_futures.append(run(cc_contours[j].save_contour, template_dir / f"contour_{j}.txt")) io_futures.append(run(cc_contours[j].save_thickness_values, template_dir / f"thickness_values_{j}.txt")) @@ -267,17 +256,14 @@ def _gen_slice2slab_vox2vox(_slice_idx: int) -> AffineMatrix4x4: logger.info(f"Saving overlay file to {overlay_file_path}") io_futures.append(run(cc_mesh.write_morph_data, overlay_file_path)) - if any(wants_output(f"cc_{n}") for n in ("thickness_image", "cc_surf")): + if any(wants_output(f"cc_{n}") for n in ("thickness_image", "surf")): import nibabel as nib up_data: Image3d[np.uint8] = np.empty(upright_header["dims"][:3], dtype=upright_header.get_data_dtype()) upright_img = nib.MGHImage(up_data, fsavg_vox2ras, upright_header) # the mesh is generated in upright coordinates, so we need to also transform to orig coordinates - # FIXME: this is currently not in RAS coordinates! - # Mesh is fsavg_midplane (RAS); we need to transform to voxel coordinates # fsavg ras is also on the midslice, so this is fine and we multiply in the IA and SP offsets cc_mesh = cc_mesh.to_vox_coordinates(mesh_ras2vox=np.linalg.inv(fsavg_vox2ras @ orig2fsavg_vox2vox)) - #FIXME: to_fs_coordinate needs to transform from upright to if wants_output("cc_thickness_image"): # this will also write overlay and surface thickness_image_path = output_path("cc_thickness_image") @@ -318,7 +304,6 @@ def recon_cc_surf_measure( subdivisions: list[float], subdivision_method: SubdivisionMethod, contour_smoothing: int, - vox_size: tuple[float, float, float], ) -> tuple[CCMeasuresDict, ContourThickness, tuple[int, int]]: """Reconstruct surfaces and compute measures for a single slice for the corpus callosum. @@ -342,8 +327,6 @@ def recon_cc_surf_measure( Method for contour subdivision ('shape', 'vertical', 'angular', or 'eigenvector'). contour_smoothing : int Gaussian sigma for contour smoothing. - vox_size : triplet of floats - LIA-oriented voxel size in millimeters. Returns ------- @@ -478,7 +461,7 @@ def recon_cc_surf_measure( def test_right_of_line( - coords: Polygon2dType, + coords: Points2dType, line_start: Vector2d, line_end: Vector2d, ) -> np.ndarray[tuple[int], np.dtype[np.bool_]]: @@ -487,27 +470,28 @@ def test_right_of_line( Parameters ---------- coords : np.ndarray - Array of coordinates of shape (2, N). + Array of coordinates of shape (..., N). line_start : array-like - [x, y] coordinates of line start point. + [x, y] coordinates of line start point (N,). line_end : array-like - [x, y] coordinates of line end point. + [x, y] coordinates of line end point (N,). Returns ------- np.ndarray - Boolean array where True means point is to the left of the line. + Boolean array where True means point is to the left of the line of shape coords.shape[:-1]. 
""" # Vector from line_start to line_end - line_vec = np.array(line_end) - np.array(line_start) + line_start_arr = np.expand_dims(line_start, axis=np.arange(line_start.ndim, coords.ndim).tolist()) + line_vec = np.expand_dims(line_end, axis=np.arange(line_end.ndim, coords.ndim).tolist()) - line_start_arr # Vectors from line_start to all points (vectorized) - point_vec = coords - np.expand_dims(line_start, axis=list(range(1, coords.ndim))) + point_vec = np.moveaxis(coords, -1, 0) - line_start_arr # Cross product (vectorized): positive means point is to the left of the line cross_products = line_vec[0] * point_vec[1] - line_vec[1] * point_vec[0] - return cross_products > 0 + return np.greater(cross_products, 0) def get_unique_contour_points(split_contours: ContourList) -> list[Points2dType]: @@ -535,33 +519,16 @@ def get_unique_contour_points(split_contours: ContourList) -> list[Points2dType] 3. Collects points unique to each subsegment. """ # For each contour point, check if it appears in other contours - unique_contour_points: list[Points2dType] = [] - - for i, contour in enumerate(split_contours): - # Get points for this contour - contour_points: Points2dType = np.vstack((contour[0], -contour[1])).T # Shape: (N,2) - - # Check each point against all other contours - unique_points = [] - for point in contour_points: - is_unique = True - - # Compare against other contours - for j, other_contour in enumerate(split_contours): - if i == j: - continue - - other_points = np.vstack((other_contour[0], -other_contour[1])).T - - # Check if point exists in other contour (with small tolerance) - if np.any(np.all(np.abs(other_points - point) < 1e-6, axis=1)): - is_unique = False - break - - if is_unique: - unique_points.append(point) - - unique_contour_points.append(np.array(unique_points)) + # initialize with values for first_contour, which are by definition just "the contour" (empty) + unique_contour_points: list[Points2dType] = [np.zeros((0, 2))] + first_contour = split_contours[0] + # Check each point against all other contours + for contour in split_contours[1:]: + # 0: coord-axis, 1: contour-axis, 2: first_contour_axis + contour_comparison = np.isclose(first_contour[:, None], contour[:, :, None], atol=1e-6) + # mask of contour points, that are also in first_contour (axis 1 after all) + contour_points_in_first_contour_mask = np.any(np.all(contour_comparison, axis=0), axis=1) + unique_contour_points.append(contour[:, ~contour_points_in_first_contour_mask].T) return unique_contour_points @@ -569,7 +536,7 @@ def get_unique_contour_points(split_contours: ContourList) -> list[Points2dType] def make_subdivision_mask( slice_shape: Shape2d, split_contours: ContourList, - vox_size: tuple[float, float], + vox2ras: AffineMatrix4x4, ) -> np.ndarray[Shape2d, np.dtype[np.int_]]: """Create a mask for subdividing the corpus callosum based on split contours. @@ -580,8 +547,8 @@ def make_subdivision_mask( split_contours : ContourList List of contours defining the subdivisions. Each contour is a tuple of x and y coordinates. - vox_size : pair of floats - The voxel sizes of the image grid in AS orientation. + vox2ras : AffineMatrix4x4 + The vox2ras transformation matrix for the requested shape. Returns ------- @@ -599,6 +566,7 @@ def make_subdivision_mask( - Tests which points lie to the right of the line. - Updates labels for those points. 
""" + from nibabel.affines import apply_affine # unique_contour_points are the points where sub-division lines were inserted unique_contour_points: list[Points2dType] = get_unique_contour_points(split_contours) # shape (N, 2) @@ -610,29 +578,27 @@ def make_subdivision_mask( # Create coordinate grids for all points in the slice rows, cols = slice_shape - coords = np.array(np.mgrid[0:rows, 0:cols])[[1, 0]] + coords_vox = np.stack(np.mgrid[0:1, 0:rows, 0:cols], axis=-1) + coords_ras = apply_affine(vox2ras, coords_vox) + + cc_labels_posterior_to_anterior = SUBSEGMENT_LABELS - cc_subsegment_lut_anterior_to_posterior = SUBSEGMENT_LABELS.copy() - cc_subsegment_lut_anterior_to_posterior.reverse() - # Initialize with first segment label - subdivision_mask = np.full(slice_shape, cc_subsegment_lut_anterior_to_posterior[0], dtype=np.int32) - + subdivision_mask = np.full(slice_shape, cc_labels_posterior_to_anterior[0], dtype=np.int32) + # Process each subdivision line, subdivision_segments has for each division line the two points that are on the # contour and divide the subsegments - for segment_idx, segment_points in enumerate(subdivision_segments): + for label, segment_points in zip(cc_labels_posterior_to_anterior[1:], reversed(subdivision_segments), strict=True): # line_start and line_end are the intersection points of the CC subsegmentation boundary and the contour line + line_start, line_end = segment_points + # --> find all voxels posterior to the line in question - line_start: Vector2d = segment_points[0] / vox_size - line_end: Vector2d = segment_points[-1] / vox_size - # Vectorized test: find all points to the right of line (line_start->line_end) # right_of_line == posterior to line - points_right_of_line = test_right_of_line(coords, line_start, line_end) + points_right_of_line = test_right_of_line(coords_ras[0, ..., 1:], line_start, line_end) # All points to the right of this line belong to the next segment or beyond - subdivision_mask[points_right_of_line] = cc_subsegment_lut_anterior_to_posterior[segment_idx + 1] - + subdivision_mask[points_right_of_line] = label return subdivision_mask diff --git a/CorpusCallosum/utils/visualization.py b/CorpusCallosum/utils/visualization.py index 3a44c634..a71b3543 100644 --- a/CorpusCallosum/utils/visualization.py +++ b/CorpusCallosum/utils/visualization.py @@ -19,6 +19,9 @@ import nibabel as nib import numpy as np +from CorpusCallosum.utils.types import ContourList, Polygon2dType +from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Vector2d + def plot_standardized_space( ax_row: list[plt.Axes], @@ -127,37 +130,38 @@ def visualize_coordinate_spaces( def plot_contours( - transformed: np.ndarray, - split_contours: list[np.ndarray] | None = None, - midline_equidistant: np.ndarray | None = None, - levelpaths: list[np.ndarray] | None = None, - output_path: str | Path | list[Path] | None = None, - ac_coords: np.ndarray | None = None, - pc_coords: np.ndarray | None = None, - vox_size: tuple[float, float, float] | None = None, + slice_or_slab: Image3d, + split_contours: ContourList | None = None, + midline_equidistant: Polygon2dType | None = None, + levelpaths: list[Polygon2dType] | None = None, + output_path: str | Path | list[Path | str] | None = None, + ac_coords_vox: Vector2d | None = None, + pc_coords_vox: Vector2d | None = None, + vox2ras: AffineMatrix4x4 | None = None, title: str = "", ) -> None: """Creates a figure of the contours (shape) and the subdivisions of the corpus callosum. 
Parameters ---------- - transformed : np.ndarray - Transformed image data. + slice_or_slab : np.ndarray + Intensities of the current slice, midslice or midslab (will plot middle slice). split_contours : list[np.ndarray], optional - List of contour arrays for each subdivision (ignore contours on None). + List of contour arrays for each subdivision (ignore contours on None) in upright AS coordinates each with shape + (N, 2). midline_equidistant : np.ndarray, optional - Midline points at equidistant spacing (ignore midline on None). + Midline points at equidistant spacing (ignore midline on None) in upright AS coordinates with shape (2, N). levelpaths : list[np.ndarray], optional - List of level paths for visualization (ignore level paths on None). + List of level paths for visualization (ignore level paths on None) in upright AS coordinates each with shape + (2, N). output_path : str or Path or list of Paths, optional - Path to save the plot (do not save on None). - ac_coords : np.ndarray, optional - AC coordinates for visualization (ignore AC on None). - pc_coords : np.ndarray, optional - PC coordinates for visualization (ignore PC on None). - vox_size : triplet of floats, optional - LIA-oriented voxel size for scaling, optional if none of split_contours, midline_equidistant, or levelpaths are - provided. + Path to save the plot (show and do not save on None). + ac_coords_vox : np.ndarray, optional + AC coordinates for visualization (ignore AC on None) in LIA voxel coordinates. + pc_coords_vox : np.ndarray, optional + PC coordinates for visualization (ignore PC on None) in LIA voxel coordinates. + vox2ras : AffineMatrix4x4, optional + Slice vox2ras transformation matrix. title : str, default="" Title for the plot. @@ -166,22 +170,25 @@ def plot_contours( Creates a visualization of the corpus callosum contours and their subdivisions. If output_path is provided, saves the plot to that location. 
""" + from functools import partial + + from nibabel.affines import apply_affine - if vox_size is None and None in (split_contours, midline_equidistant, levelpaths): + if vox2ras is None and None in (split_contours, midline_equidistant, levelpaths): raise ValueError("vox_size must be provided if split_contours, midline_equidistant, or levelpaths are given.") if output_path is not None: matplotlib.use('Agg') # Use non-GUI backend # convert vox_size from LIA to AS - vox_size_ras = np.asarray([vox_size[0], vox_size[2], vox_size[1]]) if vox_size is not None else None + ras2vox = partial(apply_affine, np.linalg.inv(vox2ras)[1:, 1:]) # scale contour data by vox_size to convert from AS to AS-aligned voxel space - _split_contours = [] if split_contours is None else [sp / vox_size_ras[1:, None] for sp in split_contours] - _midline_equi = np.zeros((0, 2)) if midline_equidistant is None else midline_equidistant / vox_size_ras[None, 1:] - _levelpaths = [] if levelpaths is None else [lp / vox_size_ras[None, 1:] for lp in levelpaths] + _split_contours = [] if split_contours is None else [ras2vox(sp.T).T for sp in split_contours] + _midline_equi = np.zeros((0, 2)) if midline_equidistant is None else ras2vox(midline_equidistant) + _levelpaths = [] if levelpaths is None else [ras2vox(lp) for lp in levelpaths] - has_first_plot = not (len(_split_contours) == 0 and ac_coords is None and pc_coords is None) + has_first_plot = not (len(_split_contours) == 0 and ac_coords_vox is None and pc_coords_vox is None) num_plots = 1 + int(has_first_plot) fig, ax = plt.subplots(1, num_plots, sharex=True, sharey=True, figsize=(15, 10)) @@ -189,29 +196,30 @@ def plot_contours( # NOTE: For all plots imshow shows y inverted current_plot = 0 + # This visualization uses voxel coordinates in fsaverage space... 
if has_first_plot: - ax[current_plot].imshow(transformed[transformed.shape[0] // 2], cmap="gray") + ax[current_plot].imshow(slice_or_slab[slice_or_slab.shape[0] // 2], cmap="gray") ax[current_plot].set_title(title) if _split_contours: for i, this_contour in enumerate(_split_contours): - ax[current_plot].fill(this_contour[0, :], -this_contour[1, :], color="steelblue", alpha=0.25) + ax[current_plot].fill(this_contour[1, :], this_contour[0, :], color="steelblue", alpha=0.25) kwargs = {"color": "mediumblue", "linewidth": 0.7, "linestyle": "solid" if i != 0 else "dotted"} - ax[current_plot].plot(this_contour[0, :], -this_contour[1, :], **kwargs) - if ac_coords is not None: - ax[current_plot].scatter(ac_coords[1], ac_coords[0], color="red", marker="x") - if pc_coords is not None: - ax[current_plot].scatter(pc_coords[1], pc_coords[0], color="blue", marker="x") + ax[current_plot].plot(this_contour[1, :], this_contour[0, :], **kwargs) + if ac_coords_vox is not None: + ax[current_plot].scatter(ac_coords_vox[1], ac_coords_vox[0], color="red", marker="x") + if pc_coords_vox is not None: + ax[current_plot].scatter(pc_coords_vox[1], pc_coords_vox[0], color="blue", marker="x") current_plot += int(has_first_plot) - ax[current_plot].imshow(transformed[transformed.shape[0] // 2], cmap="gray") + ax[current_plot].imshow(slice_or_slab[slice_or_slab.shape[0] // 2], cmap="gray") for this_path in _levelpaths: - ax[current_plot].plot(this_path[:, 0], -this_path[:, 1], color="brown", linewidth=0.8) + ax[current_plot].plot(this_path[:, 1], this_path[:, 0], color="brown", linewidth=0.8) ax[current_plot].set_title("Midline & Levelpaths") if _midline_equi.shape[0] > 0: - ax[current_plot].plot(_midline_equi[:, 0], -_midline_equi[:, 1], color="red") + ax[current_plot].plot(_midline_equi[:, 1], _midline_equi[:, 0], color="red") if _split_contours: reference_contour = _split_contours[0] - ax[current_plot].plot(reference_contour[0, :], -reference_contour[1, :], color="red", linewidth=0.5) + ax[current_plot].plot(reference_contour[1, :], reference_contour[0, :], color="red", linewidth=0.5) padding = 30 for a in ax.flatten(): @@ -220,8 +228,8 @@ def plot_contours( if _split_contours: reference_contour = _split_contours[0] # get bounding box of contours - a.set_xlim(reference_contour[0, :].min() - padding, reference_contour[0, :].max() + padding) - a.set_ylim((-reference_contour[1, :]).max() + padding, (-reference_contour[1, :]).min() - padding) + a.set_xlim(reference_contour[1, :].min() - padding, reference_contour[1, :].max() + padding) + a.set_ylim((reference_contour[0, :]).max() + padding, (reference_contour[0, :]).min() - padding) if output_path is None: return plt.show() From d85068714ee2b75c7dd804b6f75cfe3d3ca72800 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Mon, 22 Dec 2025 17:56:29 +0100 Subject: [PATCH 56/68] Fix and add explanation on how change cc_visualization.py --- CorpusCallosum/cc_visualization.py | 11 +++++++++-- CorpusCallosum/data/read_write.py | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index f17af60f..aafa2966 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -5,7 +5,6 @@ import numpy as np -from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE from CorpusCallosum.data.fsaverage_cc_template import load_fsaverage_cc_template from CorpusCallosum.shape.contour import CCContour from CorpusCallosum.shape.mesh import CCMesh @@ -222,7 +221,15 @@ 
def main( cc_mesh.plot_mesh(output_path=str(output_dir / "cc_mesh.html"), **plot_kwargs) #FIXME: needs to be adapted to new interface of CCMesh.to_fs_coordinates / to_vox_coordinates - cc_mesh = cc_mesh.to_vox_coordinates(lr_offset=FSAVERAGE_MIDDLE / resolution) + # Here we need to load the np.linalg.inv(fsavg_vox2ras @ orig2fsavg_vox2vox) + # This is the same as orig2fsavg_ras2ras from cc_up.lta + # orig2fsavg_ras2ras = read_lta(output_dir / "mri/transforms/cc_up.lta") + # orig = nibabel.load(output_dir / "mri/orig.mgz") + # cc_mesh = cc_mesh.to_vox_coordinates(mesh_ras2vox=np.linalg.inv(orig2fsavg_ras2ras @ orig.affine)) + # If we are willing to screenshot here in fsavg space, this can be simplified to just fsavg_vox2ras + from CorpusCallosum.data.read_write import load_fsaverage_data + fsavg_vox2ras, _ = load_fsaverage_data(Path(__file__).parent / "data/fsaverage_data.json") + cc_mesh = cc_mesh.to_vox_coordinates(mesh_ras2vox=np.linalg.inv(fsavg_vox2ras)) logger.info(f"Writing vtk file to {output_dir / 'cc_mesh.vtk'}") cc_mesh.write_vtk(str(output_dir / "cc_mesh.vtk")) logger.info(f"Writing freesurfer surface file to {output_dir / 'cc_mesh.fssurf'}") diff --git a/CorpusCallosum/data/read_write.py b/CorpusCallosum/data/read_write.py index 673d7d31..0c20c497 100644 --- a/CorpusCallosum/data/read_write.py +++ b/CorpusCallosum/data/read_write.py @@ -129,7 +129,7 @@ def load_fsaverage_centroids(centroids_path: str | Path) -> dict[int, npt.NDArra return {int(label): np.array(centroid) for label, centroid in centroids_data.items()} -def load_fsaverage_affine(affine_path: str | Path) -> npt.NDArray[float]: +def load_fsaverage_affine(affine_path: str | Path) -> AffineMatrix4x4: """Load fsaverage affine matrix from static text file. Parameters From 90ab5d793d674884e1867a0aac9613fdad948ab1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Mon, 22 Dec 2025 19:07:50 +0100 Subject: [PATCH 57/68] Fix doc build error. --- CorpusCallosum/shape/mesh.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/CorpusCallosum/shape/mesh.py b/CorpusCallosum/shape/mesh.py index 9d3502f0..9b072a0a 100644 --- a/CorpusCallosum/shape/mesh.py +++ b/CorpusCallosum/shape/mesh.py @@ -678,9 +678,6 @@ def to_vox_coordinates( @update_docstring(parent_doc=TriaMesh.write_fssurf.__doc__) def write_fssurf(self, filename: Path | str, image: str | nibabelImage | None = None) -> None: """{parent_doc} - - Notes - ----- Also creates parent directory if needed before writing the file. """ self.__make_parent_folder(filename) From cd7c6e9583a1018df48cd48d60159e14bc16a62b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Mon, 22 Dec 2025 19:05:24 +0100 Subject: [PATCH 58/68] Update CorpusCallosum to use the new lapy.Polygon interface. 
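In practice, the change below swaps the private, name-mangled resampler on TriaMesh for the public Polygon class. A minimal before/after sketch (assuming lapy >= 1.5 with the Polygon interface used in the hunks below; the square contour is made-up data, not pipeline output):

    import lapy
    import numpy as np

    # made-up closed 3D contour: a unit square in the x=0 plane
    pts = np.array([[0., 0., 0.], [0., 1., 0.], [0., 1., 1.], [0., 0., 1.]])

    # old (lapy < 1.5): private helper, fragile against lapy updates
    # resampled = lapy.TriaMesh._TriaMesh__resample_polygon(pts, 700)

    # new: public Polygon interface, as used in contour.py below
    polygon = lapy.Polygon(pts, closed=True)
    polygon.smooth_laplace(n=5, inplace=True)  # Laplace smoothing of the ring
    polygon.resample(700, inplace=True)        # equidistant resampling
    resampled = polygon.points                 # (700, 3) vertex array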
--- CorpusCallosum/shape/contour.py | 28 +++++++++++----------- CorpusCallosum/shape/endpoint_heuristic.py | 2 ++ env/fastsurfer.yml | 2 +- pyproject.toml | 4 ++-- 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py index b1f14b2a..5bfbf21d 100644 --- a/CorpusCallosum/shape/contour.py +++ b/CorpusCallosum/shape/contour.py @@ -418,7 +418,8 @@ def plot_contour_colorfill( continue # make levelpath - path = lapy.TriaMesh._TriaMesh__resample_polygon(path, 1000) + # FIXME: this change to lapy.Polygon is untested + path = lapy.Polygon(path).resample(1000).points # Extend at the beginning: add point in direction opposite to first segment first_segment = path[1] - path[0] @@ -767,20 +768,19 @@ def from_mask_and_acpc( Expects LIA orientation. """ import skimage.measure + from nibabel.affines import apply_affine - from CorpusCallosum.shape.endpoint_heuristic import smooth_contour + _contour: Points2dType = skimage.measure.find_contours(cc_mask, level=0.5)[0] - contour = skimage.measure.find_contours(cc_mask, level=0.5)[0].T - #FIXME: maybe use Polygon.smooth_* - contour = np.array(smooth_contour(*contour, window_size=contour_smoothing)) - # Add z=0 coordinate to make 3D, then remove it after resampling - contour_3d = np.concatenate([np.zeros((1, contour.shape[1])), contour]) # ZIA, (3, N) - # FIXME: change this to using Polygon class when we upgrade lapy - contour_3d = lapy.tria_mesh.TriaMesh._TriaMesh__resample_polygon(contour_3d.T, 701).T - contour_ras = (slice_vox2ras[:3, :3] @ contour_3d) + slice_vox2ras[:3, [3]] + # FIXME: maybe CCContour should just inherit from Polygon? + polygon = lapy.polygon.Polygon(np.concatenate([np.zeros_like(_contour[:, :1]), _contour], axis=1), closed=True) + polygon.smooth_laplace(n=contour_smoothing, inplace=True) + polygon.resample(701, inplace=True) - ac_pc_3d = np.concatenate([[[0, 0]], np.stack([ac_2d, pc_2d], axis=1)]) # (3, 2) - ac_ras, pc_ras = ((slice_vox2ras[:3, :3] @ ac_pc_3d) + slice_vox2ras[:3, [3]]).T - endpoint_idx = find_cc_endpoints(contour_ras[1:], ac_ras[1:], pc_ras[1:]) + contour_ras = apply_affine(slice_vox2ras, polygon.points) - return cls(contour_ras[1:].T, None, endpoint_idx, z_position=slice_vox2ras[0, 3]) + ac_pc_3d = np.concatenate([[[0], [0]], np.stack([ac_2d, pc_2d], axis=0)], axis=1) # (2, 3) + ac_ras, pc_ras = apply_affine(slice_vox2ras, ac_pc_3d) + endpoint_idx = find_cc_endpoints(contour_ras[:, 1:].T, ac_ras[1:], pc_ras[1:]) + + return cls(contour_ras[:, 1:], None, endpoint_idx, z_position=slice_vox2ras[0, 3]) diff --git a/CorpusCallosum/shape/endpoint_heuristic.py b/CorpusCallosum/shape/endpoint_heuristic.py index f9dc084c..d6e84227 100644 --- a/CorpusCallosum/shape/endpoint_heuristic.py +++ b/CorpusCallosum/shape/endpoint_heuristic.py @@ -272,6 +272,8 @@ def find_contour_and_endpoints( # Add z=0 coordinate to make 3D, then remove it after resampling #FIXME: change this to using Polygon class when we upgrade lapy + # IMPORTANT: this is incompatible with lapy Polygon update (lapy 1.5?), however, find_contour_and_endpoints is not + # used any more in favor of CCContour.from_cc_mask contour_3d = lapy.tria_mesh.TriaMesh._TriaMesh__resample_polygon( np.append(contour, np.zeros((1, contour.shape[1])), axis=0).T, 701, diff --git a/env/fastsurfer.yml b/env/fastsurfer.yml index a3c5231d..eb637803 100644 --- a/env/fastsurfer.yml +++ b/env/fastsurfer.yml @@ -5,7 +5,6 @@ channels: dependencies: - h5py==3.12.1 -- lapy==1.4.0 - matplotlib==3.10.1 - monai==1.4.0 - 
nibabel==5.3.2 @@ -26,6 +25,7 @@ dependencies: - yacs==0.1.8 - pip: - --extra-index-url https://download.pytorch.org/whl/cu126 + - lapy=git+https://github.com/Deep-MI/lapy.git@main - simpleitk==2.4.1 - torch==2.6.0+cu126 - torchio==0.20.4 diff --git a/pyproject.toml b/pyproject.toml index 46dd1c40..c73b15f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,10 +33,10 @@ classifiers = [ ] dependencies = [ 'h5py>=3.7', - "lapy>=1.4.0", + "lapy @ git+https://github.com/Deep-MI/lapy.git@main", 'matplotlib>=3.7.1', 'nibabel>=5.1.0', - 'numpy>=1.25,<2', + 'numpy>=1.25', 'pandas>=1.5.3', 'pyyaml>=6.0', 'requests>=2.31.0', From eb411a5dd14f120f91e30fe6d1b2419cfbbda2b6 Mon Sep 17 00:00:00 2001 From: ClePol Date: Tue, 23 Dec 2025 11:12:34 +0100 Subject: [PATCH 59/68] fixed for subsegmentation and lapy+cc contour --- CorpusCallosum/fastsurfer_cc.py | 7 +++++-- CorpusCallosum/shape/contour.py | 5 +++-- CorpusCallosum/shape/postprocessing.py | 5 ++++- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index 55473043..63374c97 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -180,7 +180,7 @@ def _set_help_sid(action): ) parser.add_argument( "--contour_smoothing", - type=float, + type=int, default=5, help="Gaussian sigma for smoothing during contour detection. Higher values mean a smoother CC outline, at the " "cost of precision.", @@ -877,6 +877,7 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: # save segmentation labels, this if sd.has_attribute("cc_segmentation"): + sd.filename_by_attribute("cc_segmentation").parent.mkdir(exist_ok=True, parents=True) io_futures.append(thread_executor().submit( nib.save, nib.MGHImage(cc_fn_seg_labels, fsaverage_midslab_vox2ras, orig.header), @@ -956,6 +957,7 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: if sd.has_attribute("cc_mid_measures"): + sd.filename_by_attribute('cc_mid_measures').parent.mkdir(exist_ok=True, parents=True) io_futures.append(thread_executor().submit( save_cc_measures_json, sd.filename_by_attribute('cc_mid_measures'), @@ -963,6 +965,7 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: )) if sd.has_attribute("cc_measures"): + sd.filename_by_attribute("cc_measures").parent.mkdir(exist_ok=True, parents=True) io_futures.append(thread_executor().submit( save_cc_measures_json, sd.filename_by_attribute("cc_measures"), @@ -972,7 +975,7 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: # save lta to fsaverage space if sd.has_attribute("upright_lta"): - sd.filename_by_attribute("cc_mid_measures").parent.mkdir(exist_ok=True, parents=True) + sd.filename_by_attribute("upright_lta").parent.mkdir(exist_ok=True, parents=True) logger.info(f"Saving LTA to fsaverage space: {sd.filename_by_attribute('upright_lta')}") io_futures.append(thread_executor().submit( write_lta, diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py index 5bfbf21d..8119ceb1 100644 --- a/CorpusCallosum/shape/contour.py +++ b/CorpusCallosum/shape/contour.py @@ -773,10 +773,11 @@ def from_mask_and_acpc( _contour: Points2dType = skimage.measure.find_contours(cc_mask, level=0.5)[0] # FIXME: maybe CCContour should just inherit from Polygon? 
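+        # skimage.measure.find_contours closes a loop by repeating its first point
+        # at the end, while a Polygon constructed with closed=True treats the ring
+        # as closed on its own, so the duplicate point is dropped before
+        # smoothing/resampling.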
+ # remove last, duplicate point + _contour = _contour[:-1] polygon = lapy.polygon.Polygon(np.concatenate([np.zeros_like(_contour[:, :1]), _contour], axis=1), closed=True) polygon.smooth_laplace(n=contour_smoothing, inplace=True) - polygon.resample(701, inplace=True) - + polygon.resample(700, inplace=True) contour_ras = apply_affine(slice_vox2ras, polygon.points) ac_pc_3d = np.concatenate([[[0], [0]], np.stack([ac_2d, pc_2d], axis=0)], axis=1) # (2, 3) diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index 14bc3999..2c2f7e1c 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -581,7 +581,10 @@ def make_subdivision_mask( coords_vox = np.stack(np.mgrid[0:1, 0:rows, 0:cols], axis=-1) coords_ras = apply_affine(vox2ras, coords_vox) - cc_labels_posterior_to_anterior = SUBSEGMENT_LABELS + # Use only as many labels as needed based on the number of subdivisions + # Number of regions = number of division lines + 1 + num_labels_needed = len(subdivision_segments) + 1 + cc_labels_posterior_to_anterior = SUBSEGMENT_LABELS[:num_labels_needed] # Initialize with first segment label subdivision_mask = np.full(slice_shape, cc_labels_posterior_to_anterior[0], dtype=np.int32) From b3b0de9ecf26513998d94c1f688624265d06ab47 Mon Sep 17 00:00:00 2001 From: ClePol Date: Tue, 23 Dec 2025 13:44:10 +0100 Subject: [PATCH 60/68] added advanced curvature metrics and refactored curvature calculation --- CorpusCallosum/fastsurfer_cc.py | 2 + CorpusCallosum/shape/contour.py | 2 +- CorpusCallosum/shape/curvature.py | 129 +++++++++++++++++++++ CorpusCallosum/shape/endpoint_heuristic.py | 2 +- CorpusCallosum/shape/postprocessing.py | 67 ++++------- CorpusCallosum/shape/subsegment_contour.py | 77 +++++++++++- CorpusCallosum/shape/thickness.py | 27 +---- CorpusCallosum/utils/types.py | 4 +- 8 files changed, 235 insertions(+), 75 deletions(-) create mode 100644 CorpusCallosum/shape/curvature.py diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index 63374c97..b9f07d91 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -93,6 +93,8 @@ "total_area", "total_perimeter", "thickness_profile", + "curvature_subsegments", + "curvature_body", ] diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py index 8119ceb1..5cb5eae4 100644 --- a/CorpusCallosum/shape/contour.py +++ b/CorpusCallosum/shape/contour.py @@ -775,7 +775,7 @@ def from_mask_and_acpc( # FIXME: maybe CCContour should just inherit from Polygon? # remove last, duplicate point _contour = _contour[:-1] - polygon = lapy.polygon.Polygon(np.concatenate([np.zeros_like(_contour[:, :1]), _contour], axis=1), closed=True) + polygon = lapy.Polygon(np.concatenate([np.zeros_like(_contour[:, :1]), _contour], axis=1), closed=True) polygon.smooth_laplace(n=contour_smoothing, inplace=True) polygon.resample(700, inplace=True) contour_ras = apply_affine(slice_vox2ras, polygon.points) diff --git a/CorpusCallosum/shape/curvature.py b/CorpusCallosum/shape/curvature.py new file mode 100644 index 00000000..daaeebac --- /dev/null +++ b/CorpusCallosum/shape/curvature.py @@ -0,0 +1,129 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from CorpusCallosum.utils.types import ContourList, Points2dType + + +def compute_curvature(path: Points2dType) -> np.ndarray[tuple[int], np.dtype[np.float_]]: + """Compute curvature by computing edge angles. + + Parameters + ---------- + path : np.ndarray + Array of shape (N, 2) containing path coordinates. + + Returns + ------- + np.ndarray + Array of angle differences between consecutive edges. + """ + # compute curvature by computing edge angles + edges = np.diff(path, axis=0) + angles = np.arctan2(edges[:, 1], edges[:, 0]) + # compute angle differences between consecutive edges + angle_diffs = np.diff(angles) + # wrap angles to [-pi, pi] + angle_diffs = np.mod(angle_diffs + np.pi, 2 * np.pi) - np.pi + return angle_diffs + + +def compute_mean_curvature(path: Points2dType) -> float: + """Compute mean curvature of a path. + + Parameters + ---------- + path : np.ndarray + Array of shape (N, 2) containing path coordinates. + + Returns + ------- + float + Mean curvature of the path. + """ + curvature = compute_curvature(path) + if len(curvature) == 0: + return 0.0 + return np.abs(np.degrees(np.mean(curvature))).item() / len(curvature) + + +def calculate_curvature_metrics( + midline: Points2dType, + split_points: np.ndarray | None = None, + split_contours: ContourList | None = None, +) -> tuple[float, float, np.ndarray]: + """ + Calculate curvature metrics for the CC midline, including overall mean, + body (central 65%), and subsegment curvatures. + + Parameters + ---------- + midline : Points2dType + Equidistant points along the midline. + split_points : np.ndarray, optional + Points on the midline where it was split (for orthogonal subdivision). + split_contours : ContourList, optional + List of split contours (for other subdivision methods). + + Returns + ------- + mean_curvature : float + Overall mean curvature. + curvature_body : float + Mean curvature of the central 65% of the midline. + curvature_subsegments : np.ndarray + Mean curvature for each subsegment. 
+ """ + mean_curvature = compute_mean_curvature(midline) + + num_midline_points = len(midline) + # central 65% means we remove 17.5% from each end + start_idx_body = int(num_midline_points * 0.175) + end_idx_body = int(num_midline_points * 0.825) + curvature_body = compute_mean_curvature(midline[start_idx_body:end_idx_body]) + + # Find split indices on the midline for subsegment curvature + split_indices_midline = [0] + if split_points is not None: + for sp in split_points: + idx = np.argmin(np.linalg.norm(midline - sp, axis=1)) + split_indices_midline.append(idx) + elif split_contours is not None: + from CorpusCallosum.shape.subsegment_contour import get_unique_contour_points + unique_points = get_unique_contour_points(split_contours) + for line_pts in unique_points[1:]: + if len(line_pts) == 2: + # find where this line crosses the midline + # use the average of the two points and find closest point on midline + mid_pt = np.mean(line_pts, axis=0) + idx = np.argmin(np.linalg.norm(midline - mid_pt, axis=1)) + split_indices_midline.append(idx) + + split_indices_midline.append(len(midline) - 1) + split_indices_midline.sort() + + _curvature_subsegments = [] + for i in range(len(split_indices_midline) - 1): + s_idx = split_indices_midline[i] + e_idx = split_indices_midline[i + 1] + if e_idx - s_idx >= 2: # need at least 3 points for curvature + curv = compute_mean_curvature(midline[s_idx : e_idx + 1]) + else: + curv = 0.0 + _curvature_subsegments.append(curv) + curvature_subsegments = np.asarray(_curvature_subsegments) + + return mean_curvature, curvature_body, curvature_subsegments + diff --git a/CorpusCallosum/shape/endpoint_heuristic.py b/CorpusCallosum/shape/endpoint_heuristic.py index d6e84227..ddb5af4b 100644 --- a/CorpusCallosum/shape/endpoint_heuristic.py +++ b/CorpusCallosum/shape/endpoint_heuristic.py @@ -146,7 +146,7 @@ def extract_cc_contour(cc_mask: Mask2d, contour_smoothing: int = 5) -> Polygon2d Returns ------- - lapy.polygon.Polygon + lapy.Polygon A lapy Polygon object with a closed polygon contour. 
""" cc_mask = connect_diagonally_connected_components(cc_mask) diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index 2c2f7e1c..86d2653b 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -24,19 +24,27 @@ import FastSurferCNN.utils.logging as logging from CorpusCallosum.data.constants import CC_LABEL, SUBSEGMENT_LABELS from CorpusCallosum.shape.contour import CCContour +from CorpusCallosum.shape.curvature import calculate_curvature_metrics from CorpusCallosum.shape.endpoint_heuristic import connect_diagonally_connected_components from CorpusCallosum.shape.mesh import CCMesh from CorpusCallosum.shape.metrics import calculate_cc_index from CorpusCallosum.shape.subsegment_contour import ( ContourList, get_primary_eigenvector, + get_unique_contour_points, hampel_subdivide_contour, subdivide_contour, subsegment_midline_orthogonal, transform_to_acpc_standard, ) from CorpusCallosum.shape.thickness import cc_thickness -from CorpusCallosum.utils.types import CCMeasuresDict, ContourThickness, Points2dType, SliceSelection, SubdivisionMethod +from CorpusCallosum.utils.types import ( + CCMeasuresDict, + ContourThickness, + Points2dType, + SliceSelection, + SubdivisionMethod, +) from CorpusCallosum.utils.visualization import plot_contours from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Mask2d, Shape2d, Shape3d, Vector2d from FastSurferCNN.utils.common import SubjectDirectory, update_docstring @@ -411,11 +419,16 @@ def recon_cc_surf_measure( # Apply different subdivision methods based on user choice split_contours: ContourList + split_points_midline: np.ndarray | None = None if subdivision_method == "shape": _subdivisions = np.asarray(subdivisions) - areas, split_contours = subsegment_midline_orthogonal(midline_equi, _subdivisions, contour_as, plot=False) - split_contours = [transform_to_acpc_standard(split_contour, *acpc_contour_coords_as)[0] - for split_contour in split_contours] + areas, split_contours, split_points_midline = subsegment_midline_orthogonal( + midline_equi, _subdivisions, contour_as, plot=False + ) + split_contours = [ + transform_to_acpc_standard(split_contour, *acpc_contour_coords_as)[0] + for split_contour in split_contours + ] elif subdivision_method == "vertical": areas, split_contours = subdivide_contour(contour_in_acpc_space, subdivisions, plot=False) elif subdivision_method == "angular": @@ -442,6 +455,11 @@ def recon_cc_surf_measure( # Transform split contours back to original space split_contours = [rotate_back_acpc(split_contour) for split_contour in split_contours] + # Calculate curvature metrics + curvature, curvature_body, curvature_subsegments = calculate_curvature_metrics( + midline_equi, split_points=split_points_midline, split_contours=split_contours + ) + measures: CCMeasuresDict = { "cc_index": cc_index, "circularity": circularity, @@ -449,6 +467,8 @@ def recon_cc_surf_measure( "midline_length": midline_len, "thickness": thickness, "curvature": curvature, + "curvature_subsegments": curvature_subsegments, + "curvature_body": curvature_body, "thickness_profile": thickness_profile, "total_area": total_area, "total_perimeter": total_perimeter, @@ -494,45 +514,6 @@ def test_right_of_line( return np.greater(cross_products, 0) -def get_unique_contour_points(split_contours: ContourList) -> list[Points2dType]: - """Get unique contour points from the split contours. 
- - Parameters - ---------- - split_contours : ContourList - List of split contours (subsegmentations), each containing x and y coordinates, each of shape (2, N). - - Returns - ------- - list[np.ndarray] - List of unique contour points for each subsegment, each of shape (N, 2). - - Notes - ----- - This is a workaround to retrospectively add voxel-based subdivision. - In the future, we could keep track of the subdivision lines for - every subdivision scheme. - - The function: - 1. Processes each contour point. - 2. Checks if it appears in other contours (with small tolerance). - 3. Collects points unique to each subsegment. - """ - # For each contour point, check if it appears in other contours - # initialize with values for first_contour, which are by definition just "the contour" (empty) - unique_contour_points: list[Points2dType] = [np.zeros((0, 2))] - first_contour = split_contours[0] - # Check each point against all other contours - for contour in split_contours[1:]: - # 0: coord-axis, 1: contour-axis, 2: first_contour_axis - contour_comparison = np.isclose(first_contour[:, None], contour[:, :, None], atol=1e-6) - # mask of contour points, that are also in first_contour (axis 1 after all) - contour_points_in_first_contour_mask = np.any(np.all(contour_comparison, axis=0), axis=1) - unique_contour_points.append(contour[:, ~contour_points_in_first_contour_mask].T) - - return unique_contour_points - - def make_subdivision_mask( slice_shape: Shape2d, split_contours: ContourList, diff --git a/CorpusCallosum/shape/subsegment_contour.py b/CorpusCallosum/shape/subsegment_contour.py index ea345744..b7da161b 100644 --- a/CorpusCallosum/shape/subsegment_contour.py +++ b/CorpusCallosum/shape/subsegment_contour.py @@ -110,7 +110,7 @@ def subsegment_midline_orthogonal( plot: bool = True, ax=None, extremes=None, -) -> tuple[np.ndarray[tuple[int], np.dtype[ScalarType]], ContourList]: +) -> tuple[np.ndarray[tuple[int], np.dtype[ScalarType]], ContourList, np.ndarray]: """Subsegment contour orthogonally to the midline based on area weights. Parameters @@ -134,6 +134,8 @@ def subsegment_midline_orthogonal( List of subsegment areas. split_contours : list of np.ndarray List of contour arrays for each subsegment. + split_points : np.ndarray + Array of shape (K, 2) containing points where the midline was split. """ # FIXME: Here and in other places, the order of dimensions is pretty inconsistent, for example: midline is (N, 2), # but contours are (2, N)... @@ -365,12 +367,81 @@ def subsegment_midline_orthogonal( ax.axis("equal") plt.show() - return calc_subsegment_areas(split_contours), split_contours + return calc_subsegment_areas(split_contours), split_contours, split_points + + +def get_unique_contour_points(split_contours: ContourList) -> list[Points2dType]: + """Get unique contour points from the split contours. + + Parameters + ---------- + split_contours : ContourList + List of split contours (subsegmentations), each containing x and y coordinates, each of shape (2, N). + + Returns + ------- + list[np.ndarray] + List of unique contour points for each subsegment, each of shape (N, 2). + + Notes + ----- + This is a workaround to retrospectively add voxel-based subdivision. + In the future, we could keep track of the subdivision lines for + every subdivision scheme. + + The function: + 1. Processes each contour point. + 2. Checks if it appears in other contours (with small tolerance). + 3. Collects points unique to each subsegment. 
+ """ + # For each contour point, check if it appears in other contours + # initialize with values for first_contour, which are by definition just "the contour" (empty) + unique_contour_points: list[Points2dType] = [np.zeros((0, 2))] + first_contour = split_contours[0] + # Check each point against all other contours + for contour in split_contours[1:]: + # 0: coord-axis, 1: contour-axis, 2: first_contour_axis + contour_comparison = np.isclose(first_contour[:, None], contour[:, :, None], atol=1e-6) + # mask of contour points, that are also in first_contour (axis 1 after all) + contour_points_in_first_contour_mask = np.any(np.all(contour_comparison, axis=0), axis=1) + unique_contour_points.append(contour[:, ~contour_points_in_first_contour_mask].T) + + return unique_contour_points def hampel_subdivide_contour(contour: Polygon2dType, num_rays: int, plot: bool = False, ax=None) \ -> tuple[np.ndarray[tuple[int], np.dtype[np.float_]], ContourList]: - # FIXME: needs docstring + """Subdivide contour based on area weights using equally spaced rays. + + Parameters + ---------- + contour : np.ndarray + Array of shape (2, N) containing contour points. + num_rays : int + Number of rays to use for subdivision. + plot : bool, optional + Whether to plot the results, by default False. + ax : matplotlib.axes.Axes, optional + Axes for plotting, by default None. + + Returns + ------- + areas : np.ndarray + Array of areas for each subsegment. + split_contours : list[np.ndarray] + List of contour arrays for each subsegment. + + Notes + ----- + The subdivision process: + 1. Finds extreme points in x-direction. + 2. Creates minimal bounding rectangle around contour. + 3. Creates equally spaced rays from lower edge of rectangle. + 4. Finds intersections of rays with contour. + 5. Creates new contours by splitting at intersections. + 6. Returns areas and split contours. + """ + # Find the extreme points in the x-direction min_x_index = np.argmin(contour[0]) contour = np.roll(contour, -min_x_index, axis=1) diff --git a/CorpusCallosum/shape/thickness.py b/CorpusCallosum/shape/thickness.py index 7c1c46d3..def2ca37 100644 --- a/CorpusCallosum/shape/thickness.py +++ b/CorpusCallosum/shape/thickness.py @@ -19,33 +19,11 @@ from lapy.diffgeo import compute_rotated_f from meshpy import triangle +from CorpusCallosum.shape.curvature import compute_mean_curvature from CorpusCallosum.utils.types import ContourThickness, Points2dType from FastSurferCNN.utils.common import suppress_stdout -def compute_curvature(path: Points2dType) -> np.ndarray[tuple[int], np.dtype[np.float_]]: - """Compute curvature by computing edge angles. - - Parameters - ---------- - path : np.ndarray - Array of shape (N, 2) containing path coordinates. - - Returns - ------- - np.ndarray - Array of angle differences between consecutive edges. - """ - # compute curvature by computing edge angles - edges = np.diff(path, axis=0) - angles = np.arctan2(edges[:, 1], edges[:, 0]) - # compute angle differences between consecutive edges - angle_diffs = np.diff(angles) - # wrap angles to [-pi, pi] - angle_diffs = np.mod(angle_diffs + np.pi, 2 * np.pi) - np.pi - return angle_diffs - - def set_contour_zero_idx(contour, idx, anterior_endpoint_idx, posterior_endpoint_idx): """Roll contour points to set a new zero index, while keeping track of CC endpoints. 
@@ -352,8 +330,7 @@ def cc_thickness( contour_2d_with_thickness = np.concatenate([contour_2d, contour_thickness[:, None]], axis=1) # get curvature of path3d_resampled - curvature = compute_curvature(midline_equidistant_contour_space) - mean_curvature: float = np.abs(np.degrees(np.mean(curvature))).item() / len(curvature) + mean_curvature: float = compute_mean_curvature(midline_equidistant_contour_space) mean_thickness: float = np.mean(levelpath_lengths).item() endpoints: tuple[int, int] = (anterior_endpoint_idx, posterior_endpoint_idx) diff --git a/CorpusCallosum/utils/types.py b/CorpusCallosum/utils/types.py index 52b45f9c..26d0aee8 100644 --- a/CorpusCallosum/utils/types.py +++ b/CorpusCallosum/utils/types.py @@ -66,9 +66,9 @@ class CCMeasuresDict(TypedDict): thickness_profile: ndarray[tuple[int], dtype[float]] total_area: float total_perimeter: float - total_area: float - total_perimeter: float split_contours: ContourList midline_equidistant: ndarray + curvature_subsegments: ndarray + curvature_body: float levelpaths: list[ndarray] slice_index: int From a222374966b2eca1852479b9b0cda1d69f2d2db0 Mon Sep 17 00:00:00 2001 From: David Kuegler Date: Wed, 24 Dec 2025 15:29:51 +0100 Subject: [PATCH 61/68] bump lapy to 1.5.0 (instead of install from github) FastSurfer is not compatible with numpy 2 right now, because numpy 2 has changed some of the names of datatypes. --- env/fastsurfer.yml | 2 +- pyproject.toml | 4 ++-- requirements.mac.txt | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/env/fastsurfer.yml b/env/fastsurfer.yml index eb637803..d8bc1323 100644 --- a/env/fastsurfer.yml +++ b/env/fastsurfer.yml @@ -5,6 +5,7 @@ channels: dependencies: - h5py==3.12.1 +- lapy==1.5.0 - matplotlib==3.10.1 - monai==1.4.0 - nibabel==5.3.2 @@ -25,7 +26,6 @@ dependencies: - yacs==0.1.8 - pip: - --extra-index-url https://download.pytorch.org/whl/cu126 - - lapy=git+https://github.com/Deep-MI/lapy.git@main - simpleitk==2.4.1 - torch==2.6.0+cu126 - torchio==0.20.4 diff --git a/pyproject.toml b/pyproject.toml index c73b15f3..52b1576c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,10 +33,10 @@ classifiers = [ ] dependencies = [ 'h5py>=3.7', - "lapy @ git+https://github.com/Deep-MI/lapy.git@main", + "lapy>=1.5.0", 'matplotlib>=3.7.1', 'nibabel>=5.1.0', - 'numpy>=1.25', + 'numpy>=1.25,<2', 'pandas>=1.5.3', 'pyyaml>=6.0', 'requests>=2.31.0', diff --git a/requirements.mac.txt b/requirements.mac.txt index 5f8775c6..fc713c6a 100644 --- a/requirements.mac.txt +++ b/requirements.mac.txt @@ -1,5 +1,5 @@ h5py>=3.7 -lapy>=1.0.1 +lapy>=1.5.0 matplotlib>=3.7.1 nibabel>=5.1.0 numpy>=1.25,<2 From b536b2fa5153346ff243b748d3cb107eee5dd81f Mon Sep 17 00:00:00 2001 From: ClePol Date: Mon, 5 Jan 2026 17:32:44 +0100 Subject: [PATCH 62/68] adressed fixmes, fixed bugs, added warnings --- CorpusCallosum/cc_visualization.py | 4 +- CorpusCallosum/data/fsaverage_cc_template.py | 16 +- CorpusCallosum/fastsurfer_cc.py | 58 +++-- .../segmentation_postprocessing.py | 91 +------- CorpusCallosum/shape/contour.py | 126 +++++++++-- CorpusCallosum/shape/endpoint_heuristic.py | 207 +++--------------- CorpusCallosum/shape/mesh.py | 62 ++---- CorpusCallosum/shape/postprocessing.py | 83 +++---- CorpusCallosum/shape/subsegment_contour.py | 104 +++------ CorpusCallosum/utils/mapping_helpers.py | 23 +- 10 files changed, 293 insertions(+), 481 deletions(-) diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index aafa2966..9027093d 100644 --- a/CorpusCallosum/cc_visualization.py +++ 
b/CorpusCallosum/cc_visualization.py @@ -152,7 +152,7 @@ def load_contours_from_template_dir( if fsaverage_contour is None: fsaverage_contour = load_fsaverage_cc_template() # create measurement points (points = 2 x levelpaths) according to number of thickness values - fsaverage_contour.create_levelpaths(num_points=num_thickness_values // 2, update_data=True) + fsaverage_contour.create_levelpaths(num_points=num_thickness_values // 2, inplace=True) current_contour = fsaverage_contour.copy() current_contour.z_position = z_position current_contour.load_thickness_values(thickness_file) @@ -208,7 +208,6 @@ def main( return 0 # 3D visualization - # FIXME: This function would need contours[i].z_position to be properly initialized! cc_mesh = CCMesh.from_contours(contours, smooth=0) plot_kwargs = dict( @@ -220,7 +219,6 @@ def main( cc_mesh.plot_mesh(**plot_kwargs) cc_mesh.plot_mesh(output_path=str(output_dir / "cc_mesh.html"), **plot_kwargs) - #FIXME: needs to be adapted to new interface of CCMesh.to_fs_coordinates / to_vox_coordinates # Here we need to load the np.linalg.inv(fsavg_vox2ras @ orig2fsavg_vox2vox) # This is the same as orig2fsavg_ras2ras from cc_up.lta # orig2fsavg_ras2ras = read_lta(output_dir / "mri/transforms/cc_up.lta") diff --git a/CorpusCallosum/data/fsaverage_cc_template.py b/CorpusCallosum/data/fsaverage_cc_template.py index f52a849a..0b67b767 100644 --- a/CorpusCallosum/data/fsaverage_cc_template.py +++ b/CorpusCallosum/data/fsaverage_cc_template.py @@ -120,7 +120,7 @@ def load_fsaverage_cc_template() -> CCContour: # Use the smoothed mask for further processing cc_mask = cc_mask_smoothed.astype(int) * 192 - _, contour_with_thickness, (anterior_endpoint_idx, posterior_endpoint_idx) = recon_cc_surf_measure( + _, _fsaverage_contour = recon_cc_surf_measure( segmentation=cc_mask[None], slice_idx=0, ac_coords_vox=FSAVERAGE_AC_COORDINATE, @@ -130,13 +130,13 @@ def load_fsaverage_cc_template() -> CCContour: subdivisions=[1/6, 1/2, 2/3, 3/4], subdivision_method="shape", contour_smoothing=5, - vox_size=(1., 1., 1.), # fsaverage is in 1mm isotropic ) - outside_contour = contour_with_thickness[:,:2].T + outside_contour = _fsaverage_contour.points.T + anterior_endpoint_idx, posterior_endpoint_idx = _fsaverage_contour.endpoint_idxs # make sure the CC stays in shape despite smoothing by moving endpoints outwards - outside_contour[0,anterior_endpoint_idx] -= 55 - outside_contour[0,posterior_endpoint_idx] += 30 + outside_contour[0, anterior_endpoint_idx] -= 55 + outside_contour[0, posterior_endpoint_idx] += 30 # Apply smoothing to the outside contour outside_contour_smoothed = smooth_contour(outside_contour, window_size=11) @@ -144,9 +144,9 @@ def load_fsaverage_cc_template() -> CCContour: outside_contour_smoothed = smooth_contour(outside_contour_smoothed, window_size=30) outside_contour = outside_contour_smoothed - fsaverage_contour = CCContour(np.array(outside_contour).T, - np.zeros(len(outside_contour[0])), - endpoint_idxs=(anterior_endpoint_idx, posterior_endpoint_idx), + fsaverage_contour = CCContour(np.array(outside_contour).T, + np.zeros(len(outside_contour[0])), + endpoint_idxs=(anterior_endpoint_idx, posterior_endpoint_idx), z_position=0.0) diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py index b9f07d91..7de15e71 100644 --- a/CorpusCallosum/fastsurfer_cc.py +++ b/CorpusCallosum/fastsurfer_cc.py @@ -46,6 +46,7 @@ from CorpusCallosum.localization import inference as localization_inference from CorpusCallosum.segmentation import inference as 
segmentation_inference from CorpusCallosum.segmentation import segmentation_postprocessing +from CorpusCallosum.shape.contour import calculate_volume as calculate_cc_volume_contour from CorpusCallosum.shape.postprocessing import ( check_area_changes, make_subdivision_mask, @@ -840,7 +841,7 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: # Process slices based on selection mode logger.info(f"Processing slices with selection mode: {slice_selection}") - slice_results, slice_io_futures = recon_cc_surf_measures_multi( + slice_results, slice_io_futures, cc_contours, cc_mesh = recon_cc_surf_measures_multi( segmentation=cc_fn_seg_labels, slice_selection=slice_selection, upright_header=fsavg_header, @@ -867,15 +868,7 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: # Get middle slice result middle_slice_result: CCMeasuresDict = slice_results[len(slice_results) // 2] - if len(middle_slice_result["split_contours"]) <= 5: - cc_subseg_midslice = make_subdivision_mask( - (cc_fn_seg_labels.shape[1], cc_fn_seg_labels.shape[2]), - middle_slice_result["split_contours"], - vox2ras=fsavg_vox2ras @ np.linalg.inv(fsavg2midslice_vox2vox), - ) - else: - logger.warning("Too many subsegments for lookup table, skipping sub-division of output segmentation.") - cc_subseg_midslice = None + # save segmentation labels, this if sd.has_attribute("cc_segmentation"): @@ -887,6 +880,15 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: )) # map soft labels to original space (in parallel because this takes a while, and we only do it to save the labels) if sd.has_attribute("cc_orig_segfile"): + if len(middle_slice_result["split_contours"]) <= 5: + cc_subseg_midslice = make_subdivision_mask( + (cc_fn_seg_labels.shape[1], cc_fn_seg_labels.shape[2]), + middle_slice_result["split_contours"], + vox2ras=fsavg_vox2ras @ np.linalg.inv(fsavg2midslice_vox2vox), + ) + else: + logger.warning("Too many subsegments for lookup table, skipping sub-division of output segmentation.") + cc_subseg_midslice = None # if num_threads is not large enough (>1), this might be blocking ; serial_executor runs the function in submit executor = thread_executor() if get_num_threads() > 2 else serial_executor() io_futures.append(executor.submit( @@ -919,18 +921,8 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: voxel_size=vox_size, # in LIA order ) logger.info(f"CC volume voxel: {cc_volume_voxel}") - # FIXME: Create a proper mesh and use cc_mesh.volume for this volume --> not closed, but move function to - # CCContour? 
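The partial-volume correction to `get_cc_volume_voxel` in the `segmentation_postprocessing.py` hunk just below is easiest to sanity-check with concrete numbers; the values in this sketch are illustrative, not taken from the codebase:

```python
import numpy as np

# Illustrative numbers: a CC slab 7 voxels wide at 0.8 mm per voxel,
# trimmed down to a desired 5 mm slab; the two edge voxels count partially.
voxel_width, desired_width_mm, n_voxels = 0.8, 5.0, 7

desired_width_vox = desired_width_mm / voxel_width               # 6.25 voxels
full_center_voxels = n_voxels - 2                                # 5 full voxels in the center
fraction_at_edge = (desired_width_vox - full_center_voxels) / 2  # 0.625 per edge voxel

# 5 full voxels + 2 * 0.625 partial voxels = 6.25 voxels = exactly 5 mm
effective_mm = (full_center_voxels + 2 * fraction_at_edge) * voxel_width
assert np.isclose(effective_mm, desired_width_mm)
```

The old expression `(desired_width_vox % 1) / 2` yields 0.125 here, which covers only (5 + 2 * 0.125) * 0.8 = 4.2 mm instead of the requested 5 mm.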
-    try:
-        cc_volume_contour = segmentation_postprocessing.get_cc_volume_contour(
-            cc_contours=outer_contours,
-            voxel_size=vox_size,  # in LIA order
-        )
-        logger.info(f"CC volume contour: {cc_volume_contour}")
-    except AssertionError as e:
-        logger.warning("Could not compute CC volume from contours, setting to NaN")
-        logger.exception(e)
-        cc_volume_contour = float('nan')
+    cc_volume_contour = calculate_cc_volume_contour(cc_contours, width=5.0)
+    logger.info(f"CC volume contour: {cc_volume_contour}")
 
     additional_metrics["cc_5mm_volume"] = cc_volume_voxel
     additional_metrics["cc_5mm_volume_pv_corrected"] = cc_volume_contour
@@ -957,6 +949,28 @@ def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4:
     additional_metrics["contour_smoothing"] = contour_smoothing
     additional_metrics["slice_selection"] = slice_selection
 
+    # QC checks
+    if len(outer_contours) > 1:
+        max_vol = max(cc_volume_voxel, cc_volume_contour)
+        if max_vol > 0 and abs(cc_volume_voxel - cc_volume_contour) / max_vol > 0.2:
+            logger.warning(
+                f"QC flag: CC volume estimates differ by more than 20% "
+                f"(voxel: {cc_volume_voxel:.2f}, contour: {cc_volume_contour:.2f}), "
+                "this can happen if contour creation failed for some slices"
+            )
+
+    cc_index = output_metrics_middle_slice.get("cc_index")
+    if cc_index is not None and cc_index > 2:
+        logger.warning(
+            f"QC flag: CC index is high ({cc_index:.2f} > 2), segmentation or contour creation may be incorrect"
+        )
+
+    midline_length = output_metrics_middle_slice.get("midline_length")
+    if midline_length is not None and midline_length < 30:
+        logger.warning(
+            f"QC flag: CC midline length is short ({midline_length:.2f}mm < 30mm), endpoints may be "
+            "incorrectly detected or contour creation may have failed"
+        )
 
     if sd.has_attribute("cc_mid_measures"):
         sd.filename_by_attribute('cc_mid_measures').parent.mkdir(exist_ok=True, parents=True)
diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py
index 68b238fb..c8ee2d99 100644
--- a/CorpusCallosum/segmentation/segmentation_postprocessing.py
+++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py
@@ -14,7 +14,7 @@
 import numpy as np
 from numpy import typing as npt
-from scipy import integrate, ndimage
+from scipy import ndimage
 from scipy.spatial.distance import cdist
 from skimage.measure import label
 
@@ -335,7 +335,9 @@ def get_cc_volume_voxel(
     elif width_mm > desired_width_mm:
         # remainder on the left/right side of the CC mask
         desired_width_vox = desired_width_mm / voxel_width
-        fraction_of_voxel_at_edge = (desired_width_vox % 1) / 2
+        # The number of full voxels in the center is (cc_mask.shape[0] - 2)
+        # The remaining width must be covered by the two edge voxels.
+        fraction_of_voxel_at_edge = (desired_width_vox - (cc_mask.shape[0] - 2)) / 2
 
         if fraction_of_voxel_at_edge > 0:
             # make sure the assumption is correct that the CC mask has an odd number of voxels
@@ -354,91 +356,6 @@
         raise ValueError(f"Width of CC segmentation is smaller than desired width: {width_mm} < {desired_width_mm}")
 
 
-def get_cc_volume_contour(
-    cc_contours: list[np.ndarray],
-    voxel_size: tuple[float, float, float],
-) -> float:
-    """Calculate the volume of the corpus callosum using Simpson's rule.
-
-    Parameters
-    ----------
-    cc_contours : list[np.ndarray]
-        List of CC contours for each slice in the left-right direction.
-    voxel_size : triplet of floats
-        Voxel size in millimeters (x, y, z).
- - Returns - ------- - float - Volume of the CC in cubic millimeters. - - Raises - ------ - ValueError - If CC width is smaller than desired width or insufficient contours for Simpson's rule - - Notes - ----- - This function calculates the volume of the corpus callosum (CC) in cubic millimeters - using Simpson's rule. If the CC width is larger than desired_width_mm, the voxels on - the edges are calculated as partial volumes to achieve the desired width. - """ - # FIXME: move to CCContour --> area - - - # FIXME: this code currently produces volume estimates more that 50% off of the volume_based estimate in - # get_cc_volume_voxel... - - if len(cc_contours) < 3: - raise ValueError("Need at least 3 contours for Simpson's rule integration") - - # FIXME: why can we not multiply by those numbers in line below other FIXME comment - # converting this to a warning for now... - if voxel_size[1] == voxel_size[2]: - logger.warning("voxel sizes in get_cc_volume_contour, currently volume must be isotropic!") - # Calculate cross-sectional areas for each contour - areas = [] - for contour in cc_contours: - # Calculate area using the shoelace formula for polygon area - if contour.shape[1] < 3: - areas.append(0.0) - else: - # FIXME: we are multiplying by voxel size here and below "Convert from voxel^2 to mm^2", e.g. - # x = contour[0] * voxel_size[1] - # y = contour[1] * voxel_size[2] - contour = contour * voxel_size[1] - x = contour[0] - y = contour[1] - # Shoelace formula: A = 0.5 * |sum(x_i * y_{i+1} - x_{i+1} * y_i)| - area = 0.5 * np.abs(np.sum(x[:-1] * y[1:] - x[1:] * y[:-1])) - # Convert from voxel^2 to mm^2 - area_mm2 = area * voxel_size[1] * voxel_size[2] # y * z voxel dimensions - areas.append(area_mm2) - - areas = np.array(areas) - - # Calculate spacing between slices (left-right direction) - lr_spacing = voxel_size[0] # x-direction voxel size - - measurement_points = np.arange(-voxel_size[0]*(areas.shape[0]//2), - voxel_size[0]*((areas.shape[0]+1)//2), lr_spacing) - - # FIXME: why interpolate at 0.25? Also, why do we need interpolation at all? - # interpolate areas at 0.25 and 5 - areas_interpolated = np.interp(x=[-2.5, 2.5], - xp=measurement_points, - fp=areas) - - # remove measurement points that are outside of the desired range - # not sure if this can happen, but let's be safe - outside_range = (measurement_points < -2.5) | (measurement_points > 2.5) - measurement_points = [-2.5] + measurement_points[~outside_range].tolist() + [2.5] - areas = [areas_interpolated[0]] + areas[~outside_range].tolist() + [areas_interpolated[1]] - - - # can also use trapezoidal rule - return integrate.simpson(areas, x=measurement_points) - def extract_largest_connected_component( seg_arr: Mask3d, diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py index 5cb5eae4..61f09b88 100644 --- a/CorpusCallosum/shape/contour.py +++ b/CorpusCallosum/shape/contour.py @@ -47,6 +47,11 @@ Self = TypeVar("Self", bound="CCContour") + + + + + # FIXME: Maybe CCContur should inherit from Polygon at a later date? class CCContour: """A class for representing and manipulating corpus callosum (CC) contours. @@ -123,6 +128,21 @@ def __len__(self) -> int: """Return the number of points on the contour.""" return len(self.points) + @property + def area(self) -> float: + """Calculate the area of the contour using the shoelace formula. + + Returns + ------- + float + The area of the contour. 
+ """ + if len(self.points) < 3: + return 0.0 + x = self.points[:, 0] + y = self.points[:, 1] + return 0.5 * np.abs(np.sum(x * np.roll(y, -1) - np.roll(x, -1) * y)) + def smooth_contour(self, window_size: int = 5) -> None: """Smooth a contour using a moving average filter. @@ -161,28 +181,54 @@ def get_contour_edge_lengths(self) -> np.ndarray: def create_levelpaths( self, num_points: int, - update_data: bool = True - ) -> tuple[list[np.ndarray], float]: - #FIXME: docstring + inplace: bool = False + ) -> tuple[list[np.ndarray], float, float, np.ndarray, np.ndarray, tuple[int, int], float]: + """Calculate thickness and level paths for the CC contour using Laplace equation. + + Parameters + ---------- + num_points : int + Number of points for thickness estimation. + update_data : bool, default=True + Whether to update the contour points and thickness values in place. + + Returns + ------- + levelpaths : list[np.ndarray] + List of level paths across the CC. + thickness : float + Mean thickness of the CC. + midline_len : float + Length of the CC midline. + midline_equi : np.ndarray + Equidistant points along the midline. + contour_with_thickness : np.ndarray + Contour points with thickness information, shape (N, 3). + endpoint_idxs : tuple[int, int] + Indices of the anterior and posterior endpoints on the updated contour. + curvature : float + Mean curvature of the midline. + """ + + # FIXME: cache all these values in CCContour, and invalidate the cache, when either points or endpoint_idxs get + # changed; alternatively, make points and endpoint_idxs read_only (by creating getter-only properties) + # and have all functions that change points or endpoints return a new CCContour object instead. + - # FIXME: cache all these values in CCContour, and invalidate the cache, when either points or endpoint_idxs get - # changed; alternatively, make points and endpoint_idxs read_only (by creating getter-only properties) - # and have all functions that change points or endpoints return a new CCContour object instead. midline_len, thickness, curvature, midline_equi, levelpaths, contour_with_thickness, endpoint_idxs = \ cc_thickness( self.points, self.endpoint_idxs, n_points=num_points, ) - - if update_data: - # FIXME: as an alternative to update_data, use "inplace" ; always return the CCContour object? + + if inplace: self.points = contour_with_thickness[:, :2] - self.thickness_values = contour_with_thickness[:,2] + self.thickness_values = contour_with_thickness[:, 2] self.original_thickness_vertices = np.where(~np.isnan(self.thickness_values))[0] self.endpoint_idxs = endpoint_idxs - return levelpaths, thickness + return levelpaths, thickness, midline_len, midline_equi, contour_with_thickness, endpoint_idxs, curvature def set_thickness_values(self, thickness_values: np.ndarray, use_measurement_points: bool = False) -> None: """Set the thickness values for the contour. 
@@ -380,7 +426,7 @@ def plot_contour_colorfill( # make points 3D by adding zero points = np.column_stack([points, np.zeros(len(points))]) - levelpaths, *_ = self.create_levelpaths(num_points=len(plot_values)-1, update_data=False) + levelpaths, *_ = self.create_levelpaths(num_points=len(plot_values)-1, inplace=False) outside_contour = self.points.T @@ -418,7 +464,6 @@ def plot_contour_colorfill( continue # make levelpath - # FIXME: this change to lapy.Polygon is untested path = lapy.Polygon(path).resample(1000).points # Extend at the beginning: add point in direction opposite to first segment @@ -772,7 +817,6 @@ def from_mask_and_acpc( _contour: Points2dType = skimage.measure.find_contours(cc_mask, level=0.5)[0] - # FIXME: maybe CCContour should just inherit from Polygon? # remove last, duplicate point _contour = _contour[:-1] polygon = lapy.Polygon(np.concatenate([np.zeros_like(_contour[:, :1]), _contour], axis=1), closed=True) @@ -785,3 +829,57 @@ def from_mask_and_acpc( endpoint_idx = find_cc_endpoints(contour_ras[:, 1:].T, ac_ras[1:], pc_ras[1:]) return cls(contour_ras[:, 1:], None, endpoint_idx, z_position=slice_vox2ras[0, 3]) + + + +def calculate_volume(contours: list[CCContour], width: float = 5.0) -> float: + """Calculate the volume of the corpus callosum. + + This method calculates the volume of a slab of the CC centered on the midplane. + It multiplies the area of each cross-sectional slice by the width it + represents within the slab. It assumes equally spaced contours centered + around the midplane (z=0). + + Parameters + ---------- + width : float, default=5.0 + The width of the slab centered on the midplane to calculate the volume for (in mm). + + Returns + ------- + float + The volume of the CC in cubic millimeters. + """ + if len(contours) < 2: + return 0.0 + + # Group vertices by their LR coordinate (column 0 as created by from_contours) + z_coords = [contour.z_position for contour in contours] + areas = [contour.area for contour in contours] + + contour_widths = np.diff(z_coords) + + # check that all widths are the same + if not np.allclose(contour_widths, contour_widths[0]): + raise ValueError("Contours must be equally spaced to calculate CC volume") + + contour_width_mm = abs(contour_widths[0]) + + # Define the slab boundaries centered on the midplane (z=0) + z_min, z_max = -width / 2.0, width / 2.0 + + volume = 0.0 + for i, z in enumerate(z_coords): + # Each contour represents a slab of contour_width_mm + # centered at its z position. + start = z - contour_width_mm / 2.0 + end = z + contour_width_mm / 2.0 + + # Intersection of [start, end] and [z_min, z_max] + effective_start = max(start, z_min) + effective_end = min(end, z_max) + + effective_width = max(0.0, effective_end - effective_start) + volume += areas[i] * effective_width + + return volume \ No newline at end of file diff --git a/CorpusCallosum/shape/endpoint_heuristic.py b/CorpusCallosum/shape/endpoint_heuristic.py index ddb5af4b..e67dfe3b 100644 --- a/CorpusCallosum/shape/endpoint_heuristic.py +++ b/CorpusCallosum/shape/endpoint_heuristic.py @@ -11,11 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
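Stepping back to the `calculate_volume` helper added in `shape/contour.py` above: its slab integration reduces to clipping each contour's spacing interval against the slab and summing area times the clipped width. A condensed sketch with synthetic cross-sections (names and numbers are illustrative):

```python
import numpy as np


def slab_volume(z_coords, areas, width=5.0):
    """Volume of a slab of `width` mm centered at z=0 from equally spaced sections."""
    spacing = np.diff(z_coords)
    if not np.allclose(spacing, spacing[0]):
        raise ValueError("sections must be equally spaced")
    dz = abs(spacing[0])
    z_min, z_max = -width / 2.0, width / 2.0
    volume = 0.0
    for z, area in zip(z_coords, areas):
        # intersect this section's interval [z - dz/2, z + dz/2] with the slab
        lo, hi = max(z - dz / 2.0, z_min), min(z + dz / 2.0, z_max)
        volume += area * max(0.0, hi - lo)
    return volume


# A constant 100 mm^2 cross-section over a 5 mm slab gives 500 mm^3,
# independent of how finely the slab is sampled.
z = np.linspace(-3.0, 3.0, 7)  # 1 mm spacing
print(slab_volume(z, np.full(7, 100.0)))  # 500.0
```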
-from typing import Literal, overload -import lapy.tria_mesh import numpy as np -import scipy.ndimage import skimage.measure from scipy.ndimage import label @@ -157,156 +154,12 @@ def extract_cc_contour(cc_mask: Mask2d, contour_smoothing: int = 5) -> Polygon2d return contour -@overload -def find_contour_and_endpoints( - cc_mask: Mask2d, - ac_2d: Vector2d, - pc_2d: Vector2d, - resolution: tuple[float, float], - return_coordinates: Literal[True], - contour_smoothing: int = 5 -) -> tuple[np.ndarray, tuple[int, int], tuple[Vector2d, Vector2d]]: ... - - -@overload -def find_contour_and_endpoints( - cc_mask: Mask2d, - ac_2d: Vector2d, - pc_2d: Vector2d, - resolution: tuple[float, float], - return_coordinates: Literal[False] = False, - contour_smoothing: int = 5 -) -> tuple[np.ndarray, tuple[int, int]]: ... - - -def find_contour_and_endpoints( - cc_mask: Mask2d, - ac_2d: Vector2d, - pc_2d: Vector2d, - resolution: tuple[float, float], - return_coordinates: bool = False, - contour_smoothing: int = 5 -): - """Extracts the contour of the CC, rotates to AC-PC alignment, and determines closest points of CC to AC and PC. - - Parameters - ---------- - cc_mask : np.ndarray of shape (H, W) and type bool - Binary mask of the corpus callosum. - ac_2d : np.ndarray of shape (2,) and type float - 2D voxel coordinates of the anterior commissure. - pc_2d : np.ndarray of shape (2,) and type float - 2D voxel coordinates of the posterior commissure. - resolution : pair of floats - Inslice image resolution in mm (inferior/superior and anterior/posterior directions). - return_coordinates : bool, default=False - If True, return endpoint coordinates. - contour_smoothing : int, default=5 - Window size for contour smoothing. - - Returns - ------- - contour_rotated : np.ndarray - The contour in 2d voxel coordinates rotated to AC-PC alignment and with origin at center of image - (axis 0: I->S, axis 1: A->P). - anterior_posterior_point_indices : pair of ints - Indices of anterior and posterior points in the contour. - anterior_posterior_point_coordinates : tuple[np.ndarray, np.ndarray] - Only if return_coordinates is True: Coordinates of anterior and posterior points rotated to AP-PC alignment. - - Notes - ----- - Expects LIA orientation. - """ - image_size = cc_mask.shape - - # Calculate angle between AC-PC line and horizontal using numpy - ac_pc_vector = pc_2d - ac_2d - horizontal_vector = np.array([0, -20]) - # Calculate angle using dot product formula: cos(theta) = (a·b)/(|a||b|) - dot_product = np.dot(ac_pc_vector, horizontal_vector) - norms = np.linalg.norm(ac_pc_vector) * np.linalg.norm(horizontal_vector) - theta = np.arccos(dot_product / norms) - - # Convert symbolic theta to float and convert from radians to degrees - theta_degrees = theta * 180 / np.pi - # FIXME: Why do we rotate the mask before we do the marching squares? Instead of all this weird rotation everywhere - # here, it seems to me it would be the same to just rotate the offsets for the heuristic?! 
Like so: - # - # rot_matrix_inv = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) - # rotated_pc_2d = pc_2d.astype(float) + rot_matrix_inv @ np.array([10, -5]) / resolution - # rotated_ac_2d = ac_2d.astype(float) + rot_matrix_inv @ np.array([0, 5]) / resolution - # - # contour = extract_cc_contour(cc_mask, contour_smoothing) - # # Add z=0 coordinate to make 3D, then remove it after resampling - # contour_3d = np.vstack([contour, np.zeros(contour.shape[1])]) - # contour_3d = __resample_polygon(contour_3d.T, 701).T - # contour = contour_3d[:2, :-1] - # # find point in contour closest to AC - # ac_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_ac_2d[:, None], axis=0)) - # # find point in contour closest to PC - # pc_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_pc_2d[:, None], axis=0)) - - rotated_cc_mask = scipy.ndimage.rotate(cc_mask, -theta_degrees, order=0, reshape=False) - - # rotate points around center - origin_point = np.array([image_size[0] // 2, image_size[1] // 2]) - - # Create rotation matrix for -theta - rot_matrix = np.array([[np.cos(-theta), -np.sin(-theta)], [np.sin(-theta), np.cos(-theta)]]) - - # Translate points to origin, rotate, then translate back - rotated_pc_2d = rot_matrix @ (pc_2d.astype(float) - origin_point) + origin_point - rotated_ac_2d = rot_matrix @ (ac_2d.astype(float) - origin_point) + origin_point - - # move posterior commisure 5 mm posterior, 10 mm superior - # FIXME: multiplication means moving less for smaller voxels, why not division? - # changed to division, 5 mm / voxel size => number of voxels to move - # ----> CHECK IF THESE VALUES ARE CONFIRMED GOOD IN TESTING - rotated_pc_2d = rotated_pc_2d + np.array([10, -5]) / resolution - - # move anterior commisure 5 mm anterior - rotated_ac_2d = rotated_ac_2d + np.array([0, 5]) / resolution - - contour = extract_cc_contour(rotated_cc_mask, contour_smoothing) - - # Add z=0 coordinate to make 3D, then remove it after resampling - #FIXME: change this to using Polygon class when we upgrade lapy - # IMPORTANT: this is incompatible with lapy Polygon update (lapy 1.5?), however, find_contour_and_endpoints is not - # used any more in favor of CCContour.from_cc_mask - contour_3d = lapy.tria_mesh.TriaMesh._TriaMesh__resample_polygon( - np.append(contour, np.zeros((1, contour.shape[1])), axis=0).T, - 701, - ) - contour = contour_3d.T[:2, :-1] - - # find point in contour closest to AC - ac_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_ac_2d[:, None], axis=0)) - - # find point in contour closest to PC - pc_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_pc_2d[:, None], axis=0)) - - # rotate startpoints to original orientation - origin_point = np.array(origin_point).astype(float) - # Create rotation matrix - rot_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) - - # Translate points to origin, rotate, then translate back - contour_centered = contour - origin_point[:, None] - contour_rotated = (rot_matrix @ contour_centered) + origin_point[:, None] - - if return_coordinates: - start_point_ac, start_point_pc = contour_rotated[:, [ac_startpoint_idx, pc_startpoint_idx]].T - - return contour_rotated, (ac_startpoint_idx, pc_startpoint_idx), (start_point_ac, start_point_pc) - else: - return contour_rotated, (ac_startpoint_idx, pc_startpoint_idx) def find_cc_endpoints( contour: Points2dType, ac_2d: Vector2d, pc_2d: Vector2d, - return_coordinates: bool = False, + plot: bool = False, ): """Extracts the contour of the 
CC, rotates to AC-PC alignment, and determines closest points of CC to AC and PC.
@@ -318,16 +171,12 @@ def find_cc_endpoints(
         2D AS coordinates of the anterior commissure in millimeter.
     pc_2d : np.ndarray of shape (2,) and type float
         2D AS coordinates of the posterior commissure in millimeter.
-    return_coordinates : bool, default=False
-        If True, return endpoint coordinates.
 
     Returns
     -------
     anterior_posterior_point_indices : pair of ints
         Indices of anterior and posterior points in the contour.
-    anterior_posterior_point_coordinates : pair of Vector2d
-        Only if return_coordinates is True: Coordinates of anterior and posterior points, each shape (2,).
-
+
     Notes
     -----
     Expects AS orientation of contour, ac_2d, and pc_2d.
@@ -344,28 +193,44 @@ def find_cc_endpoints(
     dot_product = np.dot(ac_pc_vector, horizontal_vector)
     norms = np.linalg.norm(ac_pc_vector) * np.linalg.norm(horizontal_vector)
     # The sign of theta is the inverse of ac_pc_vector [ X ]
-    theta = -np.sign(ac_pc_vector[0]) * np.arccos(dot_product / norms)
+    theta = np.sign(ac_pc_vector[0]) * np.arccos(dot_product / norms)
 
     rot_matrix_inv = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
 
-    # move posterior commisure 5 mm posterior, 10 mm inferior
-    # FIXME: multiplication means moving less for smaller voxels, why not division?
-    #  changed to division, 5 mm / voxel size => number of voxels to move
-    #  ----> CHECK IF THESE VALUES ARE CONFIRMED GOOD IN TESTING
-    as_offset_pc = np.array([-5, -10], dtype=float)
-    rotated_pc_2d = pc_2d.astype(float) + rot_matrix_inv @ as_offset_pc
+    # move posterior commissure 10 mm inferior, 5 mm posterior
+    as_offset_pc = np.array([10, -5], dtype=float)
+    posterior_anchor_2d = pc_2d.astype(float) + rot_matrix_inv @ as_offset_pc
 
     # move anterior commissure 5 mm anterior
-    as_offset_ac = np.array([5, 0], dtype=float)
-    rotated_ac_2d = ac_2d.astype(float) + rot_matrix_inv @ as_offset_ac
+    as_offset_ac = np.array([0, 5], dtype=float)
+    anterior_anchor_2d = ac_2d.astype(float) + rot_matrix_inv @ as_offset_ac
 
     # Find the endpoints of the CC shape relative to AC and PC coordinates
     # find point in contour closest to AC
-    ac_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_ac_2d[:, None], axis=0))
+    ac_startpoint_idx = np.argmin(np.linalg.norm(contour - anterior_anchor_2d[:, None], axis=0))
 
     # find point in contour closest to PC
-    pc_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_pc_2d[:, None], axis=0))
-
-    if return_coordinates:
-        start_point_ac, start_point_pc = contour[:, [ac_startpoint_idx, pc_startpoint_idx]].T
-
-        return (ac_startpoint_idx, pc_startpoint_idx), (start_point_ac, start_point_pc)
-    else:
-        return ac_startpoint_idx, pc_startpoint_idx
+    pc_startpoint_idx = np.argmin(np.linalg.norm(contour - posterior_anchor_2d[:, None], axis=0))
+
+
+    if plot:  # interactive debug plot of contour, ac, pc and endpoints
+        import matplotlib
+        import matplotlib.pyplot as plt
+        curr_backend = matplotlib.get_backend()
+        plt.switch_backend("qtagg")
+        plt.figure(figsize=(10, 8))
+        plt.plot(contour[0, :], contour[1, :], 'b-', label='CC Contour', linewidth=2)
+        plt.plot(ac_2d[0], ac_2d[1], 'go', markersize=10, label='AC')
+        plt.plot(pc_2d[0], pc_2d[1], 'ro', markersize=10, label='PC')
+        plt.plot(anterior_anchor_2d[0], anterior_anchor_2d[1], 'g^', markersize=10, label='Anterior Anchor')
+        plt.plot(posterior_anchor_2d[0], posterior_anchor_2d[1], 'r^', markersize=10, label='Posterior Anchor')
+        plt.plot(contour[0, ac_startpoint_idx], contour[1, ac_startpoint_idx], 'g*',
markersize=15, label='AC Endpoint') + plt.plot(contour[0, pc_startpoint_idx], contour[1, pc_startpoint_idx], 'r*', markersize=15, label='PC Endpoint') + plt.xlabel('A-S (mm)') + plt.ylabel('I-S (mm)') + plt.title('CC Contour with Endpoints') + plt.legend() + plt.axis('equal') + plt.grid(True, alpha=0.3) + plt.show() + plt.switch_backend(curr_backend) + + + return ac_startpoint_idx, pc_startpoint_idx diff --git a/CorpusCallosum/shape/mesh.py b/CorpusCallosum/shape/mesh.py index 9b072a0a..0f1859db 100644 --- a/CorpusCallosum/shape/mesh.py +++ b/CorpusCallosum/shape/mesh.py @@ -44,7 +44,6 @@ class Matrix44(np.ndarray): def _create_cap( points: np.ndarray, - trias: np.ndarray, contour: CCContour, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Create a cap mesh for one end of the corpus callosum. @@ -73,7 +72,7 @@ def _create_cap( 3. Creates triangles between consecutive level paths 4. Smooths thickness values for visualization """ - levelpaths, thickness_values = contour._create_levelpaths(points, trias) + levelpaths, thickness_values, _, _, _, _, _ = contour.create_levelpaths(num_points=len(points), inplace=False) # Create mesh from level paths level_vertices = [] @@ -125,6 +124,8 @@ def _create_cap( # Convert to numpy arrays level_vertices = np.vstack(level_vertices) + # Add z-coordinate (column 0) to make vertices 3D + level_vertices = np.hstack([np.full((len(level_vertices), 1), contour.z_position), level_vertices]) level_faces = np.vstack(level_faces) level_colors = np.concatenate(level_colors) @@ -221,6 +222,7 @@ def __init__( """ super().__init__(np.vstack(vertices), np.vstack(faces)) self.mesh_vertex_colors = vertex_values + def plot_mesh( self, @@ -429,9 +431,9 @@ def plot_mesh( fig.update_layout( scene=dict( - xaxis=dict(range=ranges[0], **{**axis_config, "title": "AP" if show_grid else ""}), - yaxis=dict(range=ranges[1], **{**axis_config, "title": "SI" if show_grid else ""}), - zaxis=dict(range=ranges[2], **{**axis_config, "title": "LR" if show_grid else ""}), + xaxis=dict(range=ranges[0], **{**axis_config, "title": "LR" if show_grid else ""}), + yaxis=dict(range=ranges[1], **{**axis_config, "title": "AP" if show_grid else ""}), + zaxis=dict(range=ranges[2], **{**axis_config, "title": "SI" if show_grid else ""}), camera=dict(eye=dict(x=1.5, y=1.5, z=1), up=dict(x=0, y=0, z=1)), aspectmode="cube", # Force equal aspect ratio aspectratio=dict(x=1, y=1, z=1), @@ -607,9 +609,9 @@ def smooth_(self, iterations: int = 1) -> None: 2. Applies Laplacian smoothing to x and y coordinates. 3. Restores original z-coordinates to maintain slice structure. """ - z_values = self.v[:, 2] + z_values = self.v[:, 0] super().smooth_(iterations) - self.v[:, 2] = z_values + self.v[:, 0] = z_values @staticmethod @@ -652,27 +654,11 @@ def to_vox_coordinates( from copy import copy new_object = copy(self) - # to LSA - # new_object.v = new_object.v[:, [2, 1, 0]] - # to voxel - # FIXME: why are the vertex positions multiplied by voxel size here? 
- # removed => for center LR, now dividing by resolution => convert fsaverage middle from mm to vox - # => remove the conversion back to mm in the end - # all other operations are independent of order of operations (distributive) - # v_vox /= vox_size[0] - # center LR - # new_object.v[:, 0] += FSAVERAGE_MIDDLE / self.resolution - # flip SI - # new_object.v[:, 1] = -new_object.v[:, 1] - - #v_vox_test = np.round(v_vox).astype(int) - # tkrRAS = Torig*[C R S 1]' # Torig: mri_info --vox2ras-tkr orig.mgz # https://surfer.nmr.mgh.harvard.edu/fswiki/CoordinateSystems new_object.v = (mesh_ras2vox[:3, :3] @ self.v.T).T + mesh_ras2vox[None, :3, 3] - # new_object.v = (vox2ras_tkr @ np.concatenate([self.v, np.ones((self.v.shape[0], 1))], axis=1).T).T[:, :3] return new_object @update_docstring(parent_doc=TriaMesh.write_fssurf.__doc__) @@ -786,41 +772,39 @@ def from_contours( vertex_values = tmp_mesh.mesh_vertex_colors if closed: - # FIXME: this functionality is untested and not used + # this functionality is untested and not used logger.warning("CCMesh.from_contours(closed=True) is untested and likely has errors.") # Close the mesh by creating caps on both ends - # Left cap (first slice) - use counterclockwise orientation - left_side_points, left_side_trias = make_mesh_from_contour(vertices[: vertex_start_indices[1]][..., :2]) - left_side_points = np.hstack([left_side_points, np.full((len(left_side_points), 1), z_coordinates[0])]) + # Left cap (first slice) + left_side_points, left_side_trias = make_mesh_from_contour(vertices[0][..., 1:]) + left_side_points = np.hstack([np.full((len(left_side_points), 1), z_coordinates[0]), left_side_points]) - # Right cap (last slice) - reverse points for proper orientation - right_side_points, right_side_trias = make_mesh_from_contour(vertices[vertex_start_indices[-1]:][..., :2]) - right_side_points = np.hstack([right_side_points, np.full((len(right_side_points), 1), z_coordinates[-1])]) + # Right cap (last slice) + right_side_points, right_side_trias = make_mesh_from_contour(vertices[-1][..., 1:]) + right_side_points = np.hstack([np.full((len(right_side_points), 1), z_coordinates[-1]), right_side_points]) # color_sides is a legacy visualization option to allow caps to have thickness colors color_sides = True if color_sides: left_side_points, left_side_trias, left_side_colors = _create_cap( - left_side_points, left_side_trias, contours[0] + left_side_points, contours[0] ) right_side_points, right_side_trias, right_side_colors = _create_cap( - right_side_points, right_side_trias, contours[-1] + right_side_points, contours[-1] ) - # reverse right side trias + # reverse right side trias for proper orientation right_side_trias = right_side_trias[:, ::-1] - else: - left_side_colors, right_side_colors = [], [] + vertex_values = np.concatenate([vertex_values, left_side_colors, right_side_colors]) + left_side_trias = left_side_trias + current_index current_index += len(left_side_points) right_side_trias = right_side_trias + current_index current_index += len(right_side_points) - # should this not be a concatenate statements? 
- vertices = [vertices, left_side_points, right_side_points] - faces = [faces, left_side_trias, right_side_trias] - vertex_values = [vertex_values, left_side_colors, right_side_colors] + vertices.extend([left_side_points, right_side_points]) + faces.extend([left_side_trias, right_side_trias]) return cls(vertices, faces, vertex_values=vertex_values) diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index 86d2653b..eb029fba 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -37,10 +37,8 @@ subsegment_midline_orthogonal, transform_to_acpc_standard, ) -from CorpusCallosum.shape.thickness import cc_thickness from CorpusCallosum.utils.types import ( CCMeasuresDict, - ContourThickness, Points2dType, SliceSelection, SubdivisionMethod, @@ -141,6 +139,10 @@ def recon_cc_surf_measures_multi( List of slice processing results. list of concurrent.futures.Future List of background IO processes. + list of CCContour + List of CC contours. + CCMesh + The CC mesh. (None if no mesh was created) """ slice_cc_measures: list[CCMeasuresDict] = [] io_futures = [] @@ -195,13 +197,9 @@ def _gen_slice2slab_vox2vox(_slice_idx: int) -> AffineMatrix4x4: logger.info(f"Calculating CC measurements for slice {slice_idx+1}{progress}") # unpack values from _results cc_measures: CCMeasuresDict = _results[0] - contour_in_as_space_and_thickness: ContourThickness = _results[1] - endpoint_idxs: tuple[int, int] = _results[2] - contour_in_as_space: Points2dType = contour_in_as_space_and_thickness[:, :2] - thickness_values: np.ndarray[tuple[int], np.dtype[np.float_]] = contour_in_as_space_and_thickness[:, 2] + _contour: CCContour = _results[1] - z_value = this_slice_vox2ras[0, 3] - cc_contours.append(CCContour(contour_in_as_space, thickness_values, endpoint_idxs, z_position=z_value)) + cc_contours.append(_contour) if cc_measures is None: # this should not happen, but just in case logger.warning(f"Slice index {slice_idx+1}{progress} returned result `None`") @@ -292,7 +290,7 @@ def _gen_slice2slab_vox2vox(_slice_idx: int) -> AffineMatrix4x4: logger.error("Error: No valid slices were found for postprocessing") raise ValueError("No valid slices were found for postprocessing") - return slice_cc_measures, io_futures + return slice_cc_measures, io_futures, cc_contours, cc_mesh if len(cc_contours) > 1 else None def _resample_thickness(contour: CCContour) -> CCContour: @@ -303,7 +301,7 @@ def _resample_thickness(contour: CCContour) -> CCContour: def recon_cc_surf_measure( - segmentation: np.ndarray[Shape2d, np.dtype[np.int_]], + segmentation: np.ndarray[Shape3d, np.dtype[np.int_]], slice_idx: int, slice_lia_vox2midslice_ras: AffineMatrix4x4, ac_coords_vox: Vector2d, @@ -312,7 +310,7 @@ def recon_cc_surf_measure( subdivisions: list[float], subdivision_method: SubdivisionMethod, contour_smoothing: int, -) -> tuple[CCMeasuresDict, ContourThickness, tuple[int, int]]: +) -> tuple[CCMeasuresDict, CCContour]: """Reconstruct surfaces and compute measures for a single slice for the corpus callosum. Parameters @@ -340,10 +338,8 @@ def recon_cc_surf_measure( ------- measures : CCMeasuresDict Dictionary containing measurements if successful. - contour_with_thickness : np.ndarray - Contour points with thickness information in fsavg_midslice_ras space, shape (3, N) for [x, y, thickness]. - endpoint_indices : pair of ints - Indices of the anterior and posterior endpoints on the contour. 
+ contour : CCContour + The contour object containing points, thickness values, and endpoint indices. Raises ------ @@ -372,40 +368,10 @@ def recon_cc_surf_measure( slice_vox2ras=slice_lia_vox2midslice_ras, contour_smoothing=contour_smoothing, ) + levelpaths, thickness, midline_len, midline_equi, contour_with_thickness, endpoint_idxs, curvature = \ + _contour.create_levelpaths(num_thickness_points, inplace=True) + contour_as = _contour.points.T - endpoint_idxs = _contour.endpoint_idxs - # FIXME: could probably also use _contour.create_levelpaths here, but that does not currently return all values - # levelpaths, thickness = _contour.create_levelpaths(num_thickness_points) - - # FIXME: If we create CCContour objects here already (as we can), we should probably return that instead of the - # contour_with_thickness value (as the CCContour has all that information as well) - - # # find_contour_and_endpoints extracts the contour and finds ac and pc endpoints for shape analysis - # # contour is in IA voxel coordinates - # contour, endpoint_idxs = find_contour_and_endpoints( - # cc_mask_slice, - # ac_coords_vox, - # pc_coords_vox, - # (vox_size[1], vox_size[2]), - # return_coordinates=False, - # contour_smoothing=contour_smoothing, - # ) - # # contour_ras uses coordinates in the fsavg_midslice_ras coordinate system, now re-order/flip slice_ia - # # coordinates to fsavg_ras coordinates. - # #FIXME: double-check the sign of the z_offset (lr) here, currently starts positive for first slice - # offsets = np.asarray([-vox_size[0] * (slice_idx - segmentation.shape[0] // 2), 0, 0, 1]) - # affine = np.concatenate([slice_lia_vox2midslice_ras[:, :3], offsets[:, None]], axis=1) - # # convert to fsavg_ras coordinates (which are mid-slice-based) - # contour_as = (slice_lia_vox2midslice_ras @ np.append(contour, 1, axis=0))[1:3] - - contour_with_thickness: ContourThickness - # cc_thickness wants contour to be in midslice_ras coordinates, i.e. millimeter distances on the respective slice. - midline_len, thickness, curvature, midline_equi, levelpaths, contour_with_thickness, endpoint_idxs = \ - cc_thickness( - contour_as.T, - endpoint_idxs, - n_points=num_thickness_points, - ) # thickness values in contour_with_thickness is not equally sampled, different shape # to compute length of paths: diff between consecutive points (N-1, 2) => norm (N-1,) => sum (1,) thickness_profile = np.stack([np.sum(np.linalg.norm(np.diff(x[:, :2], axis=0), axis=1)) for x in levelpaths]) @@ -477,7 +443,7 @@ def recon_cc_surf_measure( "levelpaths": levelpaths, "slice_index": slice_idx } - return measures, contour_with_thickness, endpoint_idxs + return measures, _contour def test_right_of_line( @@ -518,6 +484,7 @@ def make_subdivision_mask( slice_shape: Shape2d, split_contours: ContourList, vox2ras: AffineMatrix4x4, + plot: bool = False, ) -> np.ndarray[Shape2d, np.dtype[np.int_]]: """Create a mask for subdividing the corpus callosum based on split contours. @@ -530,7 +497,8 @@ def make_subdivision_mask( Each contour is a tuple of x and y coordinates. vox2ras : AffineMatrix4x4 The vox2ras transformation matrix for the requested shape. - + plot : bool, default=False + Whether to plot the subdivision mask. 
Returns ------- np.ndarray @@ -583,6 +551,21 @@ def make_subdivision_mask( # All points to the right of this line belong to the next segment or beyond subdivision_mask[points_right_of_line] = label + + if plot: # interactive debug plot + import matplotlib + import matplotlib.pyplot as plt + curr_backend = matplotlib.get_backend() + plt.switch_backend("qtagg") + plt.figure(figsize=(10, 8)) + plt.imshow(subdivision_mask, cmap='tab10') + plt.colorbar(label='Subdivision') + plt.title('CC Subdivision Mask') + plt.xlabel('X') + plt.ylabel('Y') + plt.tight_layout() + plt.show() + plt.switch_backend(curr_backend) return subdivision_mask diff --git a/CorpusCallosum/shape/subsegment_contour.py b/CorpusCallosum/shape/subsegment_contour.py index b7da161b..c448f02f 100644 --- a/CorpusCallosum/shape/subsegment_contour.py +++ b/CorpusCallosum/shape/subsegment_contour.py @@ -136,17 +136,15 @@ def subsegment_midline_orthogonal( List of contour arrays for each subsegment. split_points : np.ndarray Array of shape (K, 2) containing points where the midline was split. + + Notes + ----- + Subsegments include all previous segments. This means subsegment contour two is the outline of the union + of subsegment one and subsegment two. """ # FIXME: Here and in other places, the order of dimensions is pretty inconsistent, for example: midline is (N, 2), # but contours are (2, N)... - # FIXME: why does this code return subsegments that include all previous segments? - # get points after midline length of splits - - # get vertex closest to midline end - - # FIXME: should this not always be the posterior endpoint index? Can we not standardize this even earlier, and then - # pull this into CCContour.from_mask_and_appc? midline_end_idx = np.argmin(np.linalg.norm(contour.T - midline[-1], axis=1)) # roll contour start to midline end contour = np.roll(contour, -midline_end_idx, axis=1) @@ -161,82 +159,41 @@ def subsegment_midline_orthogonal( edge_ortho_vectors = np.column_stack((-edge_directions[:, 1], edge_directions[:, 0])) edge_ortho_vectors = edge_ortho_vectors / np.linalg.norm(edge_ortho_vectors, axis=1)[:, None] - split_contours: ContourList = [contour] + # Calculate intersections between the perpendicular lines and the contour + # vectors from split points to all contour points + vectors = contour.T[None, :, :] - split_points[:, None, :] # (K, M, 2) - # FIXME: double loop should be vectorized, see commented code below for an initial attempt (not tested) - # also, finding intersections can be done more efficiently, instead of solving linear system for each segment - # we could just look for changes in the sign of cross products - # mid_to_contour: np.ndarray = contour[:, :, None] - split_points[:, None] - # mid_to_contour_length = np.linalg.norm(mid_to_contour, axis=0) - # mid_to_contour_norm = mid_to_contour / mid_to_contour_length[None] - # sin_theta = mid_to_contour_norm[0] * edge_ortho_vectors[1] - mid_to_contour_norm[1] * edge_ortho_vectors[0] - # index_on_contour, index_on_segment = np.where(sin_theta[:-1] * sin_theta[1:] < 0) - # sin_theta_x = sin_theta[index_on_segment] - # cos_theta_x = np.sqrt(1 - sin_theta_x * sin_theta_x) - # rot_mat = np.array([[cos_theta_x, -sin_theta_x], [sin_theta_x, cos_theta_x]]) - # # rotate mid_to_contour by sin_theta - # _mid_to_intersection = rot_mat.transpose(0, -1) @ mid_to_contour[:, None, (index_on_contour, index_on_segment)] - # mid_to_intersection = cos_theta_x * _mid_to_intersection[:, 0, :] - # intersection_points = split_points[:, index_on_segment] + mid_to_intersection - 
# mid_to_intersection_length = np.linalg.norm(mid_to_intersection, axis=0) - # - # - # for segment_idx in range(split_points.shape[1]): - # mask = index_on_segment == segment_idx - # if any(mask): - # # first_index and second_index are the indices on the contour - # # _first_index and _second_index are the indices on the intersection_points of this segment - # _first_index, _second_index, *_ = np.argsort(mid_to_intersection_length[mask]) - # first_index, second_index = index_on_contour[mask][[_first_index, _second_index]] - # if first_index > second_index: - # first_index, second_index = second_index, first_index - # _first_index, _second_index = _second_index, _first_index - # # connect first and second half - # start_to_cutoff = np.hstack( - # ( - # contour[:, :first_index + 1], # includes first_index - # intersection_points[:, mask][:, [_first_index, _second_index]], - # contour[:, second_index + 1 :], # excludes second_index - # ) - # ) - # split_contours.append(start_to_cutoff) + # Calculate cross product with ortho vectors to find side of the line (numerator of t) + # x*oy - y*ox + side = vectors[:, :, 0] * edge_ortho_vectors[:, None, 1] - vectors[:, :, 1] * edge_ortho_vectors[:, None, 0] - for pt_idx, split_point in enumerate(split_points): - intersections = [] - for i in range(contour.shape[1] - 1): - # get contour segment - segment_start = contour[:, i] - segment_end = contour[:, i + 1] - segment_vector = segment_end - segment_start + # Find where the side changes sign, indicating an intersection + sign_change = (side[:, :-1] * side[:, 1:]) <= 0 - # Check for intersection with the perpendicular line - matrix = np.array([segment_vector, -edge_ortho_vectors[pt_idx]]).T - if np.linalg.matrix_rank(matrix) < 2: - continue # Skip parallel lines + split_contours: ContourList = [contour] - # Solve for intersection - t, s = np.linalg.solve(matrix, split_point - segment_start) - if 0 <= t <= 1: - intersection_point = segment_start + t * segment_vector - intersections.append((i, intersection_point)) + for pt_idx, split_point in enumerate(split_points): + # Indices of contour segments that have sign changes for this split point + seg_indices = np.where(sign_change[pt_idx])[0] + + intersections = [] + for i in seg_indices: + s0 = side[pt_idx, i] + s1 = side[pt_idx, i + 1] + if s0 == s1: + t = 0.5 + else: + t = s0 / (s0 - s1) - # import matplotlib.pyplot as plt - # plt.figure() - # plt.plot(contour[0], contour[1], 'k-') - # plt.plot(midline[:,0], midline[:,1], 'k--') - # plt.plot(split_point[0], split_point[1], 'ro') + intersection_point = contour[:, i] + t * (contour[:, i + 1] - contour[:, i]) + intersections.append((i, intersection_point)) - # plt.plot([segment_start[0], segment_end[0]], [segment_start[1], segment_end[1]], 'bo', linewidth=2) - # plt.plot([split_point[0]-edge_ortho_vectors[pt_idx][0], split_point[0]+edge_ortho_vectors[pt_idx][0]], - # [split_point[1]-edge_ortho_vectors[pt_idx][1], - # split_point[1]+edge_ortho_vectors[pt_idx][1]], 'k-', linewidth=2) - # plt.show() # get the two intersections closest to split_point intersections.sort(key=lambda x: np.linalg.norm(x[1] - split_point)) # Create new contours by splitting at intersections - if intersections: + if len(intersections) >= 2: first_index, first_intersection = intersections[1] second_index, second_intersection = intersections[0] @@ -245,7 +202,6 @@ def subsegment_midline_orthogonal( first_intersection, second_intersection = second_intersection, first_intersection first_index += 1 - # second_index += 1 # connect first and 
second half start_to_cutoff = np.hstack( @@ -258,7 +214,7 @@ def subsegment_midline_orthogonal( ) split_contours.append(start_to_cutoff) else: - raise ValueError("No intersections found, this should not happen") + raise ValueError(f"No intersections found for split point {pt_idx}, this should not happen") # plot contour to first index, then split point, then contour to second index diff --git a/CorpusCallosum/utils/mapping_helpers.py b/CorpusCallosum/utils/mapping_helpers.py index eacdbd79..1d76cd55 100644 --- a/CorpusCallosum/utils/mapping_helpers.py +++ b/CorpusCallosum/utils/mapping_helpers.py @@ -124,15 +124,13 @@ def apply_transform_to_pt(pts: Vector3d | Polygon3dType, T: AffineMatrix4x4, inv np.ndarray Transformed point coordinates, shape (3,) or (3, N). """ - # FIXME: This function is very similar to nibabel.affines.apply_affine, reduce duplication. - # Differences: Here, pts dimensions are (3,) or (3, N), in apply_affine, they are (..., D-1) for DxD affines. if inv: T = np.linalg.inv(T) if pts.ndim == 1: - return (T @ np.hstack((pts, 1)))[:3] + return nib.affines.apply_affine(T, pts) else: - return (T @ np.concatenate([pts, np.ones((1, pts.shape[1]))]))[:3] + return nib.affines.apply_affine(T, pts.T).T def calc_mapping_to_standard_space( @@ -156,16 +154,16 @@ def calc_mapping_to_standard_space( Returns ------- - upright_volume : np.ndarray - Upright transformed volume. - standardized_volume : np.ndarray - Volume in standard space. - ac_coords_standardized : np.ndarray + standardized_to_orig_vox2vox : AffineMatrix4x4 + The vox2vox transformation matrix from standard space to original space. + ac_coords_standardized : Vector3d AC coordinates in standard space. - pc_coords_standardized : np.ndarray + pc_coords_standardized : Vector3d PC coordinates in standard space. - standardized_affine : np.ndarray - Affine matrix for standard space. + ac_coords_orig : Vector3d + AC coordinates in original space. + pc_coords_orig : Vector3d + PC coordinates in original space. """ image_center = np.array(orig.shape) / 2 @@ -209,7 +207,6 @@ def calc_mapping_to_standard_space( pc_coords_orig: Vector3d = apply_transform_to_pt( pc_coords_standardized, standardized_to_orig_vox2vox, inv=False, ) - #FIXME: incorrect docstring return standardized_to_orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig From 1bdb040738e33a3822fee9ceef4a240ef970df74 Mon Sep 17 00:00:00 2001 From: ClePol Date: Tue, 6 Jan 2026 11:51:18 +0100 Subject: [PATCH 63/68] adressed review comments --- .../segmentation_postprocessing.py | 54 ++++++++++--------- CorpusCallosum/shape/contour.py | 5 -- CorpusCallosum/shape/postprocessing.py | 5 +- CorpusCallosum/shape/subsegment_contour.py | 19 +------ CorpusCallosum/transforms/segmentation.py | 4 +- 5 files changed, 34 insertions(+), 53 deletions(-) diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py index c8ee2d99..bcb1e2a0 100644 --- a/CorpusCallosum/segmentation/segmentation_postprocessing.py +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -160,7 +160,7 @@ def create_connection_line(point1: np.ndarray, point2: np.ndarray) -> list[tuple return line_points -def connect_nearby_components(seg_arr: np.ndarray, max_connection_distance: float = 3.0) -> np.ndarray: +def connect_nearby_components(seg_arr: np.ndarray, max_connection_distance: float = 3.0, plot: bool = False) -> np.ndarray: """Connect nearby disconnected components that should be connected. 
This function identifies disconnected components in the segmentation and creates @@ -172,7 +172,9 @@ def connect_nearby_components(seg_arr: np.ndarray, max_connection_distance: floa Input binary segmentation array. max_connection_distance : float, optional Maximum distance to connect components, by default 3.0. - + plot : bool, optional + Whether to plot the segmentation with connected components, by default False. + Returns ------- np.ndarray @@ -255,25 +257,30 @@ def connect_nearby_components(seg_arr: np.ndarray, max_connection_distance: floa logger.info(f"Created {connections_made} minimal connections between components") - # Plot components for visualization - # import matplotlib.pyplot as plt - # n_components = len(component_sizes) - # fig, axes = plt.subplots(1, n_components + 1, figsize=(5*(n_components + 1), 5)) - # if n_components == 1: - # axes = [axes] - # # Plot each component in a different color - # for i, (comp_id, comp_size) in enumerate(component_sizes): - # component_mask = labels_cc == comp_id - # axes[i].imshow(component_mask[component_mask.shape[0]//2], cmap='gray') - # axes[i].set_title(f'Component {comp_id}\nSize: {comp_size}') - # axes[i].axis('off') - - # # Plot the connected segmentation - # axes[-1].imshow(connected_seg[connected_seg.shape[0]//2], cmap='gray') - # axes[-1].set_title('Connected Segmentation') - # axes[-1].axis('off') - # plt.tight_layout() - # plt.show() + # Plot components for debugging + if plot: + import matplotlib + import matplotlib.pyplot as plt + curr_backend = matplotlib.get_backend() + plt.switch_backend("qtagg") + n_components = len(component_sizes) + fig, axes = plt.subplots(1, n_components + 1, figsize=(5*(n_components + 1), 5)) + if n_components == 1: + axes = [axes] + # Plot each component in a different color + for i, (comp_id, comp_size) in enumerate(component_sizes): + component_mask = labels_cc == comp_id + axes[i].imshow(component_mask[component_mask.shape[0]//2], cmap='gray') + axes[i].set_title(f'Component {comp_id}\nSize: {comp_size}') + axes[i].axis('off') + + # Plot the connected segmentation + axes[-1].imshow(connected_seg[connected_seg.shape[0]//2], cmap='gray') + axes[-1].set_title('Connected Segmentation') + axes[-1].axis('off') + plt.tight_layout() + plt.show() + plt.switch_backend(curr_backend) return connected_seg @@ -396,11 +403,6 @@ def extract_largest_connected_component( logger.info(f"Successfully reduced components from {original_components} to {connected_components} " "using minimal connections") mask = connected_seg - # else: - # logger.info("No connections made, falling back to dilation approach") - # # Fallback: use the original dilation approach - # struct1 = ndimage.generate_binary_structure(3, 3) - # mask = ndimage.binary_dilation(seg_arr, structure=struct1, iterations=1).astype(np.uint8) # Get connected components from the processed mask labels_cc = label(mask, connectivity=3, background=0) diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py index 61f09b88..fb64ba05 100644 --- a/CorpusCallosum/shape/contour.py +++ b/CorpusCallosum/shape/contour.py @@ -421,11 +421,6 @@ def plot_contour_colorfill( """ plot_values = plot_values[::-1] # make sure values are plotted left to right (anterior to posterior) - points, _ = make_mesh_from_contour(self.points, max_volume=0.5, min_angle=25, verbose=False) - - # make points 3D by adding zero - points = np.column_stack([points, np.zeros(len(points))]) - levelpaths, *_ = self.create_levelpaths(num_points=len(plot_values)-1, inplace=False) 
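The `plot=True` debug blocks added in this patch (in `find_cc_endpoints`, `make_subdivision_mask`, and `connect_nearby_components` above) all follow the same pattern: switch to an interactive backend, draw, then restore the previous backend. A reusable sketch of that pattern, with a `try`/`finally` guard added that the in-line blocks omit:

```python
import matplotlib
import matplotlib.pyplot as plt


def show_debug_image(image, title: str = "debug") -> None:
    """Show an image interactively, then restore the previously active backend."""
    previous_backend = matplotlib.get_backend()
    try:
        plt.switch_backend("qtagg")  # assumes a Qt binding is installed
        plt.figure(figsize=(8, 6))
        plt.imshow(image, cmap="gray")
        plt.title(title)
        plt.axis("off")
        plt.show()
    finally:
        plt.switch_backend(previous_backend)
```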
outside_contour = self.points.T diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index eb029fba..23661da6 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -239,7 +239,6 @@ def _gen_slice2slab_vox2vox(_slice_idx: int) -> AffineMatrix4x4: template_dir.mkdir(parents=True, exist_ok=True) logger.info("Saving template files (contours.txt, thickness_values.txt, " f"thickness_measurement_points.txt) to {template_dir}") - run = run for j in range(len(cc_contours)): io_futures.append(run(cc_contours[j].save_contour, template_dir / f"contour_{j}.txt")) io_futures.append(run(cc_contours[j].save_thickness_values, template_dir / f"thickness_values_{j}.txt")) @@ -415,7 +414,9 @@ def recon_cc_surf_measure( raise ValueError(f"Invalid subdivision method {subdivision_method}") total_area = np.sum(areas) - total_perimeter = np.sum(np.sqrt(np.sum((np.diff(contour_as, axis=0))**2, axis=1))) + # total_perimeter should include the edge from last to first point + contour_closed = np.concatenate([contour_as, contour_as[:, :1]], axis=1) + total_perimeter = np.sum(np.linalg.norm(np.diff(contour_closed, axis=1), axis=0)) circularity = 4 * np.pi * total_area / (total_perimeter**2) # Transform split contours back to original space diff --git a/CorpusCallosum/shape/subsegment_contour.py b/CorpusCallosum/shape/subsegment_contour.py index c448f02f..9cd9a87a 100644 --- a/CorpusCallosum/shape/subsegment_contour.py +++ b/CorpusCallosum/shape/subsegment_contour.py @@ -238,15 +238,6 @@ def subsegment_midline_orthogonal( if plot: extremes = [midline[0], midline[-1]] - plot_transform = None - if plot_transform is not None: - split_contours = [plot_transform(split_contour) for split_contour in split_contours] - contour = plot_transform(contour) - extremes = [plot_transform(extreme[:, None]) for extreme in extremes] - split_points = [plot_transform(split_point[:, None]) for split_point in split_points] - # split_points_vlines_start = plot_transform(split_points_vlines_start) - # split_points_vlines_end = plot_transform(split_points_vlines_end) - import matplotlib.pyplot as plt if ax is None: @@ -912,7 +903,6 @@ def get_primary_eigenvector(contour_ras: Polygon2dType) -> tuple[Vector2d, Vecto # Sort in descending order idx = eigenvalues.argsort()[::-1] - eigenvalues = eigenvalues[idx] eigenvectors = eigenvectors[:, idx] # make first eigenvector unit length @@ -920,13 +910,6 @@ def get_primary_eigenvector(contour_ras: Polygon2dType) -> tuple[Vector2d, Vecto pt0 = np.mean(contour_ras, axis=1) pt0 -= np.array([0, 5]) pt1 = pt0 + primary_eigenvector * 100 - # plot mask with eigentvector - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots(1,2,figsize=(10, 8)) - # ax[0].imshow(cc_mask, cmap='gray') - # # plot line between pt0 and pt1 - # ax[0].plot([pt0[0], pt1[0]], [pt0[1], pt1[1]], 'r-', linewidth=2) - # plt.show() - + return pt0, pt1 diff --git a/CorpusCallosum/transforms/segmentation.py b/CorpusCallosum/transforms/segmentation.py index 2b54b450..26943e25 100644 --- a/CorpusCallosum/transforms/segmentation.py +++ b/CorpusCallosum/transforms/segmentation.py @@ -88,8 +88,8 @@ def __call__(self, data: dict) -> dict: ac_pc_bottomleft = np.min(ac_pc, axis=0).astype(int) ac_pc_topright = np.max(ac_pc, axis=0).astype(int) - VoxPadType = np.ndarray[tuple[Literal[2]], np.dtype[np.int_]] - voxel_padding: VoxPadType = np.round(self.padding_mm / d["res"]).astype(int) + voxel_padding: np.ndarray[tuple[Literal[2]], np.dtype[np.int_]] = np.round( 
+ self.padding_mm / d["res"]).astype(int) crop_left = ac_pc_bottomleft[1] - int(voxel_padding[0] * 1.5) + random_translate[0] crop_right = ac_pc_topright[1] + voxel_padding[0] // 2 + random_translate[0] From 3958a973ecc558e098c27796cbfa94ea0fa5f942 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Mon, 5 Jan 2026 17:49:05 +0100 Subject: [PATCH 64/68] Fix the types of CCMeasuresDict.thickness_profile and ContourList --- CorpusCallosum/utils/types.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CorpusCallosum/utils/types.py b/CorpusCallosum/utils/types.py index 26d0aee8..f39ace37 100644 --- a/CorpusCallosum/utils/types.py +++ b/CorpusCallosum/utils/types.py @@ -1,6 +1,6 @@ from typing import Literal, TypedDict -from numpy import dtype, ndarray +from numpy import dtype, ndarray, float_ from FastSurferCNN.utils import ScalarType @@ -20,7 +20,7 @@ Polygon3dType = ndarray[tuple[Literal[3], int], dtype[ScalarType]] Points2dType = ndarray[tuple[int, Literal[2]], dtype[ScalarType]] Points3dType = ndarray[tuple[int, Literal[3]], dtype[ScalarType]] -ContourList = list[Polygon2dType] +ContourList = list[type[Polygon2dType]] ContourThickness = ndarray[tuple[Literal[3], int], dtype[ScalarType]] SliceSelection = Literal["middle", "all"] | int SubdivisionMethod = Literal["shape", "vertical", "angular", "eigenvector"] @@ -63,7 +63,7 @@ class CCMeasuresDict(TypedDict): midline_length: float thickness: float curvature: float - thickness_profile: ndarray[tuple[int], dtype[float]] + thickness_profile: ndarray[tuple[int], dtype[float_]] total_area: float total_perimeter: float split_contours: ContourList From 3b930b56ecb6bcf62d6c1e4bed7009c2e0f8d44d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Tue, 6 Jan 2026 14:31:44 +0100 Subject: [PATCH 65/68] Fix endpoint extraction and rotation of offsets (for endpoint extraction) Empty line formatting --- CorpusCallosum/shape/contour.py | 13 ++----------- CorpusCallosum/shape/endpoint_heuristic.py | 11 ++++------- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py index fb64ba05..c03e005e 100644 --- a/CorpusCallosum/shape/contour.py +++ b/CorpusCallosum/shape/contour.py @@ -47,11 +47,6 @@ Self = TypeVar("Self", bound="CCContour") - - - - - # FIXME: Maybe CCContur should inherit from Polygon at a later date? class CCContour: """A class for representing and manipulating corpus callosum (CC) contours. @@ -81,8 +76,6 @@ class CCContour: >>> contour.save_thickness_measurement_points("thickness_measurement_points_0.txt") """ - - def __init__( self, points: Points2dType, @@ -189,7 +182,7 @@ def create_levelpaths( ---------- num_points : int Number of points for thickness estimation. - update_data : bool, default=True + inplace : bool, default=True Whether to update the contour points and thickness values in place. Returns @@ -214,7 +207,6 @@ def create_levelpaths( # changed; alternatively, make points and endpoint_idxs read_only (by creating getter-only properties) # and have all functions that change points or endpoints return a new CCContour object instead. 
- midline_len, thickness, curvature, midline_equi, levelpaths, contour_with_thickness, endpoint_idxs = \ cc_thickness( self.points, @@ -826,7 +818,6 @@ def from_mask_and_acpc( return cls(contour_ras[:, 1:], None, endpoint_idx, z_position=slice_vox2ras[0, 3]) - def calculate_volume(contours: list[CCContour], width: float = 5.0) -> float: """Calculate the volume of the corpus callosum. @@ -877,4 +868,4 @@ def calculate_volume(contours: list[CCContour], width: float = 5.0) -> float: effective_width = max(0.0, effective_end - effective_start) volume += areas[i] * effective_width - return volume \ No newline at end of file + return volume diff --git a/CorpusCallosum/shape/endpoint_heuristic.py b/CorpusCallosum/shape/endpoint_heuristic.py index e67dfe3b..01943c2e 100644 --- a/CorpusCallosum/shape/endpoint_heuristic.py +++ b/CorpusCallosum/shape/endpoint_heuristic.py @@ -154,7 +154,6 @@ def extract_cc_contour(cc_mask: Mask2d, contour_smoothing: int = 5) -> Polygon2d return contour - def find_cc_endpoints( contour: Points2dType, ac_2d: Vector2d, @@ -188,7 +187,7 @@ def find_cc_endpoints( # Calculate angle between AC-PC line and horizontal using numpy ac_pc_vector = pc_2d - ac_2d - horizontal_vector = np.array([0, -20]) + horizontal_vector = np.array([-20, 0]) # Calculate angle using dot product formula: cos(theta) = (a·b)/(|a||b|) dot_product = np.dot(ac_pc_vector, horizontal_vector) norms = np.linalg.norm(ac_pc_vector) * np.linalg.norm(horizontal_vector) @@ -196,11 +195,11 @@ def find_cc_endpoints( theta = np.sign(ac_pc_vector[0]) * np.arccos(dot_product / norms) rot_matrix_inv = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) - # move posterior commisure 10 mm inferior, 5 mm posterior - as_offset_pc = np.array([10, -5], dtype=float) + # move posterior commissure 5 mm posterior, 10 mm inferior + as_offset_pc = np.array([-5, -10], dtype=float) posterior_anchor_2d = pc_2d.astype(float) + rot_matrix_inv @ as_offset_pc # move anterior commisure 5 mm anterior - as_offset_ac = np.array([0, 5], dtype=float) + as_offset_ac = np.array([5, 0], dtype=float) anterior_anchor_2d = ac_2d.astype(float) + rot_matrix_inv @ as_offset_ac # Find the endpoints of the CC shape relative to AC and PC coordinates @@ -209,7 +208,6 @@ def find_cc_endpoints( # find point in contour closest to PC pc_startpoint_idx = np.argmin(np.linalg.norm(contour - posterior_anchor_2d[:, None], axis=0)) - if plot: # interactive debug plot of contour, ac, pc and endpoints import matplotlib import matplotlib.pyplot as plt @@ -232,5 +230,4 @@ def find_cc_endpoints( plt.show() plt.switch_backend(curr_backend) - return ac_startpoint_idx, pc_startpoint_idx From 3035c7e910b1d6def25b6f3da19f08e9fd3c3d65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Tue, 6 Jan 2026 14:58:28 +0100 Subject: [PATCH 66/68] Fix doc and style Fix some typing --- .../segmentation_postprocessing.py | 39 +++++++++++-------- CorpusCallosum/shape/contour.py | 4 +- CorpusCallosum/shape/postprocessing.py | 5 ++- CorpusCallosum/utils/types.py | 2 +- 4 files changed, 28 insertions(+), 22 deletions(-) diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py index bcb1e2a0..ac10490c 100644 --- a/CorpusCallosum/segmentation/segmentation_postprocessing.py +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -11,21 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. +from typing import TypeVar import numpy as np -from numpy import typing as npt from scipy import ndimage from scipy.spatial.distance import cdist from skimage.measure import label import FastSurferCNN.utils.logging as logging from CorpusCallosum.data.constants import CC_LABEL -from FastSurferCNN.utils import Mask3d, Shape3d +from CorpusCallosum.utils.types import Points3dType +from FastSurferCNN.utils import Mask3d, ScalarType, Shape3d, ShapeType, Vector3d logger = logging.get_logger(__name__) +ArrayType = TypeVar('ArrayType', bound=np.ndarray) -def find_component_boundaries(labels_arr: npt.NDArray[int], component_id: int) -> npt.NDArray[int]: +def find_component_boundaries(labels_arr: np.ndarray[ShapeType, np.dtype[ScalarType]], component_id: int) \ + -> np.ndarray[ShapeType, np.dtype[np.integer]]: """Find boundary voxels of a connected component. Parameters @@ -61,10 +66,10 @@ def find_component_boundaries(labels_arr: npt.NDArray[int], component_id: int) - def find_minimal_connection_path( - boundary_coords1: np.ndarray, - boundary_coords2: np.ndarray, + boundary_coords1: Points3dType, + boundary_coords2: Points3dType, max_distance: float = 3.0 -) -> tuple[np.ndarray, np.ndarray] | None: +) -> tuple[Vector3d, Vector3d] | None: """Find the minimal connection path between two component boundaries. Parameters @@ -107,7 +112,7 @@ def find_minimal_connection_path( return None -def create_connection_line(point1: np.ndarray, point2: np.ndarray) -> list[tuple[int, int, int]]: +def create_connection_line(point1: Vector3d, point2: Vector3d) -> list[tuple[int, int, int]]: """Create a line of voxels connecting two points. Uses a simplified 3D line algorithm to create a sequence of voxels @@ -122,7 +127,7 @@ def create_connection_line(point1: np.ndarray, point2: np.ndarray) -> list[tuple Returns ------- - list[tuple[int, int, int]] + list of int triplets List of (x, y, z) coordinates forming the connection line. Notes @@ -133,7 +138,7 @@ def create_connection_line(point1: np.ndarray, point2: np.ndarray) -> list[tuple x1, y1, z1 = map(int, point1) x2, y2, z2 = map(int, point2) - line_points = [] + line_points: list[tuple[int, int, int]] = [] # Calculate the number of steps needed dx = abs(x2 - x1) @@ -160,7 +165,8 @@ def create_connection_line(point1: np.ndarray, point2: np.ndarray) -> list[tuple return line_points -def connect_nearby_components(seg_arr: np.ndarray, max_connection_distance: float = 3.0, plot: bool = False) -> np.ndarray: +def connect_nearby_components(seg_arr: ArrayType, max_connection_distance: float = 3.0, plot: bool = False) \ + -> ArrayType: """Connect nearby disconnected components that should be connected. This function identifies disconnected components in the segmentation and creates @@ -170,10 +176,10 @@ def connect_nearby_components(seg_arr: np.ndarray, max_connection_distance: floa ---------- seg_arr : np.ndarray Input binary segmentation array. - max_connection_distance : float, optional - Maximum distance to connect components, by default 3.0. - plot : bool, optional - Whether to plot the segmentation with connected components, by default False. + max_connection_distance : float, default=3.0 + Maximum distance to connect components. + plot : bool, default=False + Whether to plot the segmentation with connected components.
Returns ------- @@ -235,7 +241,7 @@ def connect_nearby_components(seg_arr: np.ndarray, max_connection_distance: floa f"Distance: {distance:.2f} voxels") # Create connection line - connection_line = create_connection_line(point1, point2) + connection_line: list[tuple[int, int, int]] = create_connection_line(point1, point2) # Add connection voxels to the segmentation # Use the same label as the original segmentation at the connection points @@ -287,7 +293,7 @@ def connect_nearby_components(seg_arr: np.ndarray, max_connection_distance: floa def get_cc_volume_voxel( desired_width_mm: int, - cc_mask: np.ndarray, + cc_mask: Mask3d, voxel_size: tuple[float, float, float], ) -> float: """Calculate the volume of the corpus callosum in cubic millimeters. @@ -368,7 +374,7 @@ def extract_largest_connected_component( seg_arr: Mask3d, max_connection_distance: float = 3.0, ) -> Mask3d: - """Get largest connected component from a binary segmentation array. + """Get the largest connected component from a binary segmentation array. Parameters ---------- diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py index c03e005e..9bc28ade 100644 --- a/CorpusCallosum/shape/contour.py +++ b/CorpusCallosum/shape/contour.py @@ -38,7 +38,7 @@ import FastSurferCNN.utils.logging as logging from CorpusCallosum.shape.endpoint_heuristic import find_cc_endpoints, smooth_contour -from CorpusCallosum.shape.thickness import cc_thickness, make_mesh_from_contour +from CorpusCallosum.shape.thickness import cc_thickness from CorpusCallosum.utils.types import Points2dType from FastSurferCNN.utils import AffineMatrix4x4, Mask2d, Vector2d @@ -811,7 +811,7 @@ def from_mask_and_acpc( polygon.resample(700, inplace=True) contour_ras = apply_affine(slice_vox2ras, polygon.points) - ac_pc_3d = np.concatenate([[[0], [0]], np.stack([ac_2d, pc_2d], axis=0)], axis=1) # (2, 3) + ac_pc_3d = np.concatenate([np.zeros((2, 1), like=ac_2d), np.stack([ac_2d, pc_2d], axis=0)], axis=1) # (2, 3) ac_ras, pc_ras = apply_affine(slice_vox2ras, ac_pc_3d) endpoint_idx = find_cc_endpoints(contour_ras[:, 1:].T, ac_ras[1:], pc_ras[1:]) diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index 23661da6..8ea8a781 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -141,8 +141,8 @@ def recon_cc_surf_measures_multi( List of background IO processes. list of CCContour List of CC contours. - CCMesh - The CC mesh. (None if no mesh was created) + CCMesh, None + The CC mesh or None if no mesh was created. """ slice_cc_measures: list[CCMeasuresDict] = [] io_futures = [] @@ -500,6 +500,7 @@ def make_subdivision_mask( The vox2ras transformation matrix for the requested shape. plot : bool, default=False Whether to plot the subdivision mask. 
+ Returns ------- np.ndarray diff --git a/CorpusCallosum/utils/types.py b/CorpusCallosum/utils/types.py index f39ace37..78ad3b81 100644 --- a/CorpusCallosum/utils/types.py +++ b/CorpusCallosum/utils/types.py @@ -1,6 +1,6 @@ from typing import Literal, TypedDict -from numpy import dtype, ndarray, float_ +from numpy import dtype, float_, ndarray from FastSurferCNN.utils import ScalarType From 16dbb8fdbdee1f3400418fce9edb7cfba9a838de Mon Sep 17 00:00:00 2001 From: ClePol Date: Tue, 6 Jan 2026 15:26:14 +0100 Subject: [PATCH 67/68] fixed stats (subseg-area, ordering, etc.), improved visualization --- CorpusCallosum/cc_visualization.py | 3 + .../segmentation_postprocessing.py | 52 ++++----- CorpusCallosum/shape/contour.py | 42 +++++--- CorpusCallosum/shape/curvature.py | 6 +- CorpusCallosum/shape/endpoint_heuristic.py | 4 - CorpusCallosum/shape/mesh.py | 17 ++- CorpusCallosum/shape/metrics.py | 11 +- CorpusCallosum/shape/postprocessing.py | 9 +- CorpusCallosum/shape/subsegment_contour.py | 102 ++++++++++++------ CorpusCallosum/shape/thickness.py | 10 +- CorpusCallosum/utils/visualization.py | 4 +- doc/overview/modules/CC.md | 18 ++-- doc/scripts/cc_visualization.rst | 8 +- 13 files changed, 168 insertions(+), 118 deletions(-) diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index 9027093d..714f44dd 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -160,6 +160,9 @@ def load_contours_from_template_dir( else: current_contour = CCContour.from_contour_file(contour_file, thickness_file, z_position=z_position) + if smoothing_window > 0: + current_contour.smooth_contour(window_size=smoothing_window) + current_contour.fill_thickness_values() contours.append(current_contour) diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py index ac10490c..a0b6e272 100644 --- a/CorpusCallosum/segmentation/segmentation_postprocessing.py +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -325,45 +325,42 @@ def get_cc_volume_voxel( Notes ----- - The function assumes LIA orientation where: - - x dimension corresponds to Left/Right - - y dimension corresponds to Inferior/Superior - - z dimension corresponds to Anterior/Posterior + The function assumes LIA orientation. """ - assert cc_mask.shape[0] % 2 == 1, "CC mask must have odd number of voxels in x dimension" + # Get the bounding box of the CC mask in x dimension + any_cc = np.any(cc_mask, axis=(1, 2)) + if not np.any(any_cc): + return 0.0 + + first_x = np.argmax(any_cc) + last_x = len(any_cc) - 1 - np.argmax(any_cc[::-1]) + + # Crop mask to its extent in x + cropped_mask = cc_mask[first_x : last_x + 1] + width_vox = cropped_mask.shape[0] + assert width_vox % 2 == 1, f"CC mask must have odd number of voxels in x dimension, but has {width_vox}" # Calculate voxel volume voxel_volume: float = np.prod(voxel_size, dtype=float) voxel_width: float = voxel_size[0] - # Get width of CC mask in voxels by finding the extent in x dimension - width_vox = np.sum(np.any(cc_mask, axis=(1,2))) - # we are in LIA, so 0 is L/R resolution width_mm = width_vox * voxel_width if width_mm == desired_width_mm: - return np.sum(cc_mask) * voxel_volume + return np.sum(cropped_mask) * voxel_volume elif width_mm > desired_width_mm: # remainder on the left/right side of the CC mask desired_width_vox = desired_width_mm / voxel_width - # The number of full voxels in the center is (cc_mask.shape[0] - 2) + + # The number of full voxels
in the center is (width_vox - 2) # The remaining width must be covered by the two edge voxels. - fraction_of_voxel_at_edge = (desired_width_vox - (cc_mask.shape[0] - 2)) / 2 - - if fraction_of_voxel_at_edge > 0: - # make sure the assumentation is correct that the CC mask has an odd number of voxels - # and the leftmost and rightmost voxels are the edges at the desired width - cc_width_vox = int(np.floor(desired_width_vox) + 1) - cc_width_vox = cc_width_vox + 1 if cc_width_vox % 2 == 0 else cc_width_vox - - assert cc_mask.shape[0] == cc_width_vox, (f"CC mask should have {cc_width_vox} voxels, " - f"but has {cc_mask.shape[0]}") + fraction_of_voxel_at_edge = (desired_width_vox - (width_vox - 2)) / 2 - left_partial_volume = np.sum(cc_mask[0]) * voxel_volume * fraction_of_voxel_at_edge - right_partial_volume = np.sum(cc_mask[-1]) * voxel_volume * fraction_of_voxel_at_edge - center_volume = np.sum(cc_mask[1:-1]) * voxel_volume + left_partial_volume = np.sum(cropped_mask[0]) * voxel_volume * fraction_of_voxel_at_edge + right_partial_volume = np.sum(cropped_mask[-1]) * voxel_volume * fraction_of_voxel_at_edge + center_volume = np.sum(cropped_mask[1:-1]) * voxel_volume return left_partial_volume + right_partial_volume + center_volume else: raise ValueError(f"Width of CC segmentation is smaller than desired width: {width_mm} < {desired_width_mm}") @@ -445,14 +442,7 @@ def clean_cc_segmentation( Cleaned segmentation array with only the largest connected component of CC and fornix. mask : npt.NDArray[bool] Binary mask of the largest connected component. - - Notes - ----- - The function: - 1. Isolates the CC (label 192) - 2. Attempts to connect nearby disconnected components - 3. Adds the fornix (label 250) - 4. Removes non-connected components from the combined CC and fornix + """ from functools import partial diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py index 9bc28ade..7044ef1e 100644 --- a/CorpusCallosum/shape/contour.py +++ b/CorpusCallosum/shape/contour.py @@ -73,7 +73,6 @@ class CCContour: >>> contour.smooth_contour(window_size=5) >>> contour.save_contour("contour_0.txt") >>> contour.save_thickness_values("thickness_values_0.txt") - >>> contour.save_thickness_measurement_points("thickness_measurement_points_0.txt") """ def __init__( @@ -148,12 +147,17 @@ def smooth_contour(self, window_size: int = 5) -> None: ----- Uses smooth_contour from cc_endpoint_heuristic module. """ - self.points = np.array([smooth_contour(*self.points.T, window_size=window_size)]).T + self.points = np.array(smooth_contour(*self.points.T, window_size=window_size)).T def copy(self) -> "CCContour": """Copy the contour. """ - return CCContour(self.points.copy(), self.thickness_values.copy(), self.endpoint_idxs, self.z_position) + return CCContour( + self.points.copy(), + self.thickness_values.copy() if self.thickness_values is not None else None, + self.endpoint_idxs, + self.z_position + ) def get_contour_edge_lengths(self) -> np.ndarray: """Get the lengths of the edges of a contour. @@ -166,9 +170,10 @@ def get_contour_edge_lengths(self) -> np.ndarray: Notes ----- Edge lengths are calculated as Euclidean distances between consecutive points - in the contour. + in the contour, including the edge closing the loop between the last and + first point. 
""" - edges = np.diff(self.points, axis=0) + edges = np.roll(self.points, -1, axis=0) - self.points return np.sqrt(np.sum(edges**2, axis=1)) def create_levelpaths( @@ -291,6 +296,9 @@ def fill_thickness_values(self) -> None: return # For each point with unknown thickness + total_length = np.sum(edge_lengths) + cumulative_lengths = np.concatenate(([0], np.cumsum(edge_lengths))) + for j in range(len(thickness)): if not np.isnan(thickness[j]): continue @@ -299,10 +307,13 @@ def fill_thickness_values(self) -> None: distances = np.zeros(len(known_idx)) for k, idx in enumerate(known_idx): # Calculate distance along contour by summing edge lengths + # in both directions and taking the minimum if idx > j: - distances[k] = np.sum(edge_lengths[j:idx]) + dist_forward = cumulative_lengths[idx] - cumulative_lengths[j] else: - distances[k] = np.sum(edge_lengths[idx:j]) + dist_forward = cumulative_lengths[j] - cumulative_lengths[idx] + + distances[k] = min(dist_forward, total_length - dist_forward) # Get indices of two closest points closest_indices = known_idx[np.argsort(distances)[:2]] @@ -328,11 +339,13 @@ def smooth_thickness_values(self, iterations: int = 1) -> None: Notes ----- Applies Gaussian smoothing with sigma=5 to thickness values - for each slice that has measurements. + along the contour. """ - for i in range(len(self.thickness_values)): - if self.thickness_values[i] is not None: - self.thickness_values[i] = gaussian_filter1d(self.thickness_values[i], sigma=5) + if self.thickness_values is not None: + # Handle NaN values by interpolating if necessary or just smoothing the non-NaN parts + # Here we assume they might have been filled already by fill_thickness_values + for _ in range(iterations): + self.thickness_values = gaussian_filter1d(self.thickness_values, sigma=5, mode="wrap") def plot_contour(self, output_path: str | None = None) -> None: """Plot a single contour with thickness values. @@ -358,13 +371,14 @@ def plot_contour(self, output_path: str | None = None) -> None: # Plot points with colors based on thickness gray_points = np.isnan(self.thickness_values) if np.any(gray_points): - plt.plot(*self.points[gray_points, :2].T, "o", color="gray", markersize=1) + plt.scatter(self.points[gray_points, 0], self.points[gray_points, 1], color="gray", s=1) if not np.all(gray_points): not_gray = np.logical_not(gray_points) - color_values = plt.cm.YlOrRd(self.thickness_values[not_gray] / np.nanmax(self.thickness_values[not_gray])) # Map thickness to color from red to yellow - plt.plot(*self.points[~gray_points, :2].T, "o", color=color_values, markersize=1) + norm_thickness = self.thickness_values[not_gray] / np.nanmax(self.thickness_values[not_gray]) + color_values = plt.cm.YlOrRd(norm_thickness) + plt.scatter(self.points[not_gray, 0], self.points[not_gray, 1], c=color_values, s=1) # Connect points with lines plt.plot(self.points[:, 0], self.points[:, 1], "-", color="black", alpha=0.3, label="Contour") diff --git a/CorpusCallosum/shape/curvature.py b/CorpusCallosum/shape/curvature.py index daaeebac..214ad143 100644 --- a/CorpusCallosum/shape/curvature.py +++ b/CorpusCallosum/shape/curvature.py @@ -41,7 +41,7 @@ def compute_curvature(path: Points2dType) -> np.ndarray[tuple[int], np.dtype[np. def compute_mean_curvature(path: Points2dType) -> float: - """Compute mean curvature of a path. + """Compute mean absolute curvature of a path in degrees. Parameters ---------- @@ -51,12 +51,12 @@ def compute_mean_curvature(path: Points2dType) -> float: Returns ------- float - Mean curvature of the path. 
+ Mean absolute curvature of the path in degrees. """ curvature = compute_curvature(path) if len(curvature) == 0: return 0.0 - return np.abs(np.degrees(np.mean(curvature))).item() / len(curvature) + return np.mean(np.abs(np.degrees(curvature))).item() def calculate_curvature_metrics( diff --git a/CorpusCallosum/shape/endpoint_heuristic.py b/CorpusCallosum/shape/endpoint_heuristic.py index 01943c2e..09e8e128 100644 --- a/CorpusCallosum/shape/endpoint_heuristic.py +++ b/CorpusCallosum/shape/endpoint_heuristic.py @@ -59,10 +59,6 @@ def smooth_contour(x: np.ndarray, y: np.ndarray, window_size: int) -> tuple[np.n x_smoothed[i] = np.mean(x_padded[i : i + window_size]) y_smoothed[i] = np.mean(y_padded[i : i + window_size]) - # remove padding - x_smoothed = x_smoothed[window_size // 2:-window_size // 2] - y_smoothed = y_smoothed[window_size // 2:-window_size // 2] - return x_smoothed, y_smoothed diff --git a/CorpusCallosum/shape/mesh.py b/CorpusCallosum/shape/mesh.py index 0f1859db..f3e5f22d 100644 --- a/CorpusCallosum/shape/mesh.py +++ b/CorpusCallosum/shape/mesh.py @@ -556,7 +556,9 @@ def snap_cc_picture( fssurf_file = Path(fssurf_file) else: fssurf_file = tempfile.NamedTemporaryFile(suffix=".fssurf", delete=True).name - self.write_fssurf(fssurf_file, image=str(ref_image) if isinstance(ref_image, Path) else ref_image) + + ref_image_arg = str(ref_image) if isinstance(ref_image, (Path, str)) else ref_image + self.write_fssurf(fssurf_file, image=ref_image_arg) if overlay_file: overlay_file = Path(overlay_file) @@ -740,9 +742,10 @@ def from_contours( # Calculate z coordinates for each slice # z_coordinates = (np.arange(len(contours)) - len(contours) // 2) * contours[0].resolution + lr_center - # Build vertices list with z-coordinates + # vertices list with z-coordinates and collect thickness values vertices = [] faces = [] + vertex_values_list = [] vertex_start_indices = [] # Track starting index for each contour current_index = 0 previous_contour: CCContour | None = None @@ -750,6 +753,8 @@ def from_contours( for contour in contours: vertex_start_indices.append(current_index) vertices.append(np.hstack([np.full((len(contour.points), 1), contour.z_position), contour.points])) + if contour.thickness_values is not None: + vertex_values_list.append(contour.thickness_values) # Check if there's a next valid contour to connect to if previous_contour is not None: @@ -758,11 +763,15 @@ def from_contours( faces_between = make_triangles_between_contours(previous_contour.points, contour.points) faces.append(faces_between + current_index) - current_index += len(contour.points) + current_index += len(previous_contour.points) previous_contour = contour - vertex_values = np.concatenate([contour.thickness_values for contour in contours]) + vertex_values = None + if len(vertex_values_list) == len(contours): + vertex_values = np.concatenate(vertex_values_list) + elif len(vertex_values_list) > 0: + logger.warning("Some contours have thickness values while others don't; skipping thickness overlay") if smooth > 0: tmp_mesh = CCMesh(vertices, faces, vertex_values=vertex_values) diff --git a/CorpusCallosum/shape/metrics.py b/CorpusCallosum/shape/metrics.py index 921c819f..4003b7b7 100644 --- a/CorpusCallosum/shape/metrics.py +++ b/CorpusCallosum/shape/metrics.py @@ -164,13 +164,14 @@ def calculate_cc_index(cc_contour: np.ndarray, plot: bool = False) -> float: The CC index, which is the sum of thicknesses at three measurement points divided by AP length. 
""" # Get anterior and posterior points (extremes along x-axis) - anterior_idx = np.argmin(cc_contour[0]) # Leftmost point - posterior_idx = np.argmax(cc_contour[0]) # Rightmost point + # In ACPC space, X is Anterior-Posterior direction, where Anterior is positive + posterior_idx = np.argmin(cc_contour[0]) # Minimum X is Posterior + anterior_idx = np.argmax(cc_contour[0]) # Maximum X is Anterior anterior_pt = cc_contour[:, anterior_idx] posterior_pt = cc_contour[:, posterior_idx] - # AP line vector and properties + # AP line vector from anterior to posterior ap_vector = posterior_pt - anterior_pt ap_length = np.linalg.norm(ap_vector) ap_unit = ap_vector / ap_length @@ -213,7 +214,10 @@ def calculate_cc_index(cc_contour: np.ndarray, plot: bool = False) -> float: cc_index = (anterior_thickness + posterior_thickness + middle_thickness) / ap_distance if plot: + import matplotlib import matplotlib.pyplot as plt + curr_backend = matplotlib.get_backend() + plt.switch_backend("qtagg") fig, ax = plt.subplots(figsize=(8, 6)) plot_cc_index_calculation( @@ -227,6 +231,7 @@ def calculate_cc_index(cc_contour: np.ndarray, plot: bool = False) -> float: ) ax.legend() plt.show() + plt.switch_backend(curr_backend) return cc_index diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py index 8ea8a781..cfdb974a 100644 --- a/CorpusCallosum/shape/postprocessing.py +++ b/CorpusCallosum/shape/postprocessing.py @@ -99,7 +99,7 @@ def recon_cc_surf_measures_multi( subdivision_method: SubdivisionMethod, contour_smoothing: int, subject_dir: SubjectDirectory, -) -> tuple[list[CCMeasuresDict], list[concurrent.futures.Future]]: +) -> tuple[list[CCMeasuresDict], list[concurrent.futures.Future], list[CCContour], CCMesh | None]: """Surface reconstruction and metrics computation of corpus callosum slices based on selection mode. Parameters @@ -412,6 +412,9 @@ def recon_cc_surf_measure( split_contours = [rotate_back_eigen(split_contour) for split_contour in split_contours] else: raise ValueError(f"Invalid subdivision method {subdivision_method}") + + # order areas anterior to posterior + areas = areas[::-1] total_area = np.sum(areas) # total_perimeter should include the edge from last to first point @@ -577,7 +580,7 @@ def check_area_changes(contours: list[np.ndarray], threshold: float = 0.3) -> bo Parameters ---------- contours : list[np.ndarray] - List of contours. + List of contours (2, N). threshold : float, default=0.3 Threshold for relative change. @@ -587,7 +590,7 @@ def check_area_changes(contours: list[np.ndarray], threshold: float = 0.3) -> bo True if no large area changes are detected, False otherwise. 
""" - areas = np.asarray([np.sum(np.sqrt(np.sum((np.diff(contour, axis=0))**2, axis=1))) for contour in contours]) + areas = np.asarray([np.abs(np.trapz(c[1], c[0])) for c in contours]) assert len(areas) > 1, "At least two areas are required to check for area changes" diff --git a/CorpusCallosum/shape/subsegment_contour.py b/CorpusCallosum/shape/subsegment_contour.py index 9cd9a87a..33fbd5a3 100644 --- a/CorpusCallosum/shape/subsegment_contour.py +++ b/CorpusCallosum/shape/subsegment_contour.py @@ -43,7 +43,8 @@ def minimum_bounding_rectangle(points: Points2dType) -> np.ndarray[tuple[Literal hull_points = points[ConvexHull(points).vertices] # calculate edge angles - edges = hull_points[1:] - hull_points[:-1] + # including the edge that closes the loop from last to first point + edges = np.vstack([hull_points[1:] - hull_points[:-1], hull_points[0] - hull_points[-1]]) angles = np.arctan2(edges[:, 1], edges[:, 0]) @@ -89,18 +90,27 @@ def calc_subsegment_areas(split_contours: ContourList) -> np.ndarray[tuple[int], Parameters ---------- split_contours : list of np.ndarray - List of contour arrays, each of shape (2, N). + List of contour arrays, each of shape (2, N). The list should contain + a set of nested contours (cumulative subsegments) and the full contour. Returns ------- subsegment_areas : array of floats - Array containing the area of each subsegment. + Array containing the area of each incremental subsegment. """ # calculate area of each split contour using the shoelace formula - areas = np.abs([np.trapz(split_contour[1], split_contour[0]) for split_contour in split_contours]) - if len(areas) == 1: - return np.asarray(areas[0]) - return np.ediff1d(np.asarray(areas)[::-1], to_end=areas[-1]) + # we use the absolute value because the orientation of the contour may vary + areas_cum = np.abs([np.trapz(c[1], c[0]) for c in split_contours]) + if len(areas_cum) == 1: + return np.asarray(areas_cum[0]) + + # Sort areas to ensure they are in increasing order of size + # This handles both cases where subsegments were provided in increasing or decreasing order + # The set of areas represents a sequence of nested shapes. 
+ sorted_areas = np.sort(areas_cum) + + # Calculate the incremental pieces by taking differences between consecutive sizes + return np.diff(sorted_areas, prepend=0) def subsegment_midline_orthogonal( @@ -149,8 +159,17 @@ def subsegment_midline_orthogonal( # roll contour start to midline end contour = np.roll(contour, -midline_end_idx, axis=1) - edge_idx, edge_frac = np.divmod(len(midline) * np.array(area_weights), 1) + # Calculate edge indices and fractions for splitting the midline + # We use len(midline) - 1 because we are looking for intervals between points + edge_idx_float = (len(midline) - 1) * np.array(area_weights) + edge_idx, edge_frac = np.divmod(edge_idx_float, 1) edge_idx = edge_idx.astype(int) + + # Handle cases where area_weights might reach 1.0, which would lead to an out-of-bounds access + at_end = edge_idx >= len(midline) - 1 + edge_idx[at_end] = len(midline) - 2 + edge_frac[at_end] = 1.0 + split_points = midline[edge_idx] + (midline[edge_idx + 1] - midline[edge_idx]) * edge_frac[:, None] # get edge for each split point @@ -168,7 +187,9 @@ def subsegment_midline_orthogonal( side = vectors[:, :, 0] * edge_ortho_vectors[:, None, 1] - vectors[:, :, 1] * edge_ortho_vectors[:, None, 0] # Find where the side changes sign, indicating an intersection - sign_change = (side[:, :-1] * side[:, 1:]) <= 0 + # Handle wrap-around by appending the first side value to the end + side_wrapped = np.hstack([side, side[:, 0:1]]) + sign_change = (side_wrapped[:, :-1] * side_wrapped[:, 1:]) <= 0 split_contours: ContourList = [contour] @@ -177,15 +198,19 @@ def subsegment_midline_orthogonal( seg_indices = np.where(sign_change[pt_idx])[0] intersections = [] + num_points = contour.shape[1] for i in seg_indices: s0 = side[pt_idx, i] - s1 = side[pt_idx, i + 1] + s1 = side[pt_idx, (i + 1) % num_points] if s0 == s1: t = 0.5 else: t = s0 / (s0 - s1) - intersection_point = contour[:, i] + t * (contour[:, i + 1] - contour[:, i]) + # intersection point on the segment + p0 = contour[:, i] + p1 = contour[:, (i + 1) % num_points] + intersection_point = p0 + t * (p1 - p0) intersections.append((i, intersection_point)) @@ -442,11 +467,12 @@ def hampel_subdivide_contour(contour: Polygon2dType, num_rays: int, plot: bool = # Subdivision logic split_contours: ContourList = [] + num_points = contour.shape[1] for ray_vector in ray_vectors.T: intersections = [] - for i in range(contour.shape[1] - 1): + for i in range(num_points): segment_start = contour[:, i] - segment_end = contour[:, i + 1] + segment_end = contour[:, (i + 1) % num_points] segment_vector = segment_end - segment_start # Check for intersection with the ray @@ -455,16 +481,22 @@ def hampel_subdivide_contour(contour: Polygon2dType, num_rays: int, plot: bool = continue # Skip parallel lines # Solve for intersection - t, s = np.linalg.solve(matrix, midpoint_lower_edge - segment_start) - if 0 <= t <= 1: - intersection_point = segment_start + t * segment_vector - intersections.append((i, intersection_point)) - - # Sort intersections by their position along the contour - intersections.sort() + # matrix * [t, s]^T = midpoint_lower_edge - segment_start + try: + t, s = np.linalg.solve(matrix, midpoint_lower_edge - segment_start) + if 0 <= t < 1: # Use half-open interval to avoid double-counting vertices + intersection_point = segment_start + t * segment_vector + intersections.append((i, intersection_point)) + except np.linalg.LinAlgError: + continue # Create new contours by splitting at intersections - if intersections: + if len(intersections) >= 2: + # Sort 
intersections by their position along the contour (index) + intersections.sort(key=lambda x: x[0]) + + # For HAMPEL (radial rays), we usually expect two intersections. + # If there are more, we pick the first and last along the contour. first_index, first_intersection = intersections[0] second_index, second_intersection = intersections[-1] @@ -482,10 +514,7 @@ def hampel_subdivide_contour(contour: Polygon2dType, num_rays: int, plot: bool = else: raise ValueError("No intersections found, this should not happen") - split_contours.append(contour) - split_contours = split_contours[::-1] - - # split_contours = split_contours[::-1] + split_contours = [contour] + split_contours # Plotting logic if plot: @@ -573,20 +602,20 @@ def subdivide_contour( """ # Find the extreme points in the x-direction - min_x_index = np.argmax(contour[0]) + min_x_index = np.argmin(contour[0]) contour = np.roll(contour, -min_x_index, axis=1) min_x_index = 0 - max_x_index = np.argmin(contour[0]) + max_x_index = np.argmax(contour[0]) if oriented: contour_x_sorted = np.sort(contour[0]) min_x = contour_x_sorted[0] max_x = contour_x_sorted[-1] - extremes = (np.array([max_x, 0]), np.array([min_x, 0])) + extremes = (np.array([min_x, 0]), np.array([max_x, 0])) if hline_anchor is not None: - extremes = (np.array([max_x, hline_anchor[1]]), np.array([min_x, hline_anchor[1]])) + extremes = (np.array([min_x, hline_anchor[1]]), np.array([max_x, hline_anchor[1]])) else: extremes = (contour[:, min_x_index].copy(), contour[:, max_x_index].copy()) # Calculate the line between the extreme points @@ -676,16 +705,19 @@ def subdivide_contour( first_intersection, second_intersection = second_intersection, first_intersection first_index += 1 - # second_index += 1 - # start_to_cutoff = np.hstack((contour[:, :first_index], first_intersection[:, None], - # second_intersection[:, None], contour[:, second_index + 1:])) + # connect first and second half to create a closed cumulative loop + # that includes the start point of the contour (Posterior end) start_to_cutoff = np.hstack( - (first_intersection[:, None], contour[:, first_index:second_index], second_intersection[:, None]) + ( + contour[:, :first_index], + first_intersection[:, None], + second_intersection[:, None], + contour[:, second_index + 1 :], + ) ) - - # connect first and second half + # add cumulative subsegment split_contours.append(start_to_cutoff) else: raise ValueError("No intersections found, this should not happen") diff --git a/CorpusCallosum/shape/thickness.py b/CorpusCallosum/shape/thickness.py index def2ca37..ff6ecb05 100644 --- a/CorpusCallosum/shape/thickness.py +++ b/CorpusCallosum/shape/thickness.py @@ -68,9 +68,9 @@ def find_closest_edge(point, contour): int Index of the closest edge. 
""" - edges_start = contour[:-1, :2] # N-1 x 2 - edges_end = contour[1:, :2] # N-1 x 2 - edges_vec = edges_end - edges_start # N-1 x 2 + edges_start = contour[:, :2] # N x 2 + edges_end = np.roll(contour[:, :2], -1, axis=0) # N x 2 + edges_vec = edges_end - edges_start # N x 2 # Calculate projection coefficient for all edges at once # (p-a)·(b-a) / |b-a|² @@ -315,7 +315,7 @@ def cc_thickness( # keep track of start index if inserted_idx_start <= anterior_endpoint_idx: anterior_endpoint_idx += 1 - if inserted_idx_start >= posterior_endpoint_idx: + if inserted_idx_start <= posterior_endpoint_idx: posterior_endpoint_idx += 1 contour_2d, contour_thickness, inserted_idx_end = insert_point_with_thickness( @@ -324,7 +324,7 @@ def cc_thickness( # keep track of end index if inserted_idx_end <= anterior_endpoint_idx: anterior_endpoint_idx += 1 - if inserted_idx_end >= posterior_endpoint_idx: + if inserted_idx_end <= posterior_endpoint_idx: posterior_endpoint_idx += 1 contour_2d_with_thickness = np.concatenate([contour_2d, contour_thickness[:, None]], axis=1) diff --git a/CorpusCallosum/utils/visualization.py b/CorpusCallosum/utils/visualization.py index a71b3543..f40db0ec 100644 --- a/CorpusCallosum/utils/visualization.py +++ b/CorpusCallosum/utils/visualization.py @@ -201,9 +201,9 @@ def plot_contours( ax[current_plot].imshow(slice_or_slab[slice_or_slab.shape[0] // 2], cmap="gray") ax[current_plot].set_title(title) if _split_contours: - for i, this_contour in enumerate(_split_contours): + for this_contour in _split_contours: ax[current_plot].fill(this_contour[1, :], this_contour[0, :], color="steelblue", alpha=0.25) - kwargs = {"color": "mediumblue", "linewidth": 0.7, "linestyle": "solid" if i != 0 else "dotted"} + kwargs = {"color": "mediumblue", "linewidth": 0.7, "linestyle": "solid"} ax[current_plot].plot(this_contour[1, :], this_contour[0, :], **kwargs) if ac_coords_vox is not None: ax[current_plot].scatter(ac_coords_vox[1], ac_coords_vox[0], color="red", marker="x") diff --git a/doc/overview/modules/CC.md b/doc/overview/modules/CC.md index 5795136c..b08d5b88 100644 --- a/doc/overview/modules/CC.md +++ b/doc/overview/modules/CC.md @@ -18,19 +18,19 @@ The pipeline supports different analysis modes that determine the type of templa ### 3D Analysis -When running the main pipeline with `--slice_selection all` and `--save_template`, a complete 3D template is generated: +When running the main pipeline with `--slice_selection all` and `--save_template_dir`, a complete 3D template is generated: ```bash # Generate 3D template data -python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \ +python3 fastsurfer_cc.py --sd /data/subjects --sid sub001 \ --slice_selection all \ - --save_template /data/templates/sub001 + --save_template_dir /data/templates/sub001 ``` This creates: -- `contours.txt`: Multi-slice contour data for 3D reconstruction -- `thickness_values.txt`: Thickness measurements across all slices -- `measurement_points.txt`: 3D vertex indices for thickness measurements +- `contour_.txt`: Multi-slice contour data for 3D reconstruction +- `thickness_values_.txt`: Thickness measurements across all slices +- `thickness_measurement_points_.txt`: 3D vertex indices for thickness measurements **Benefits:** - Enables volumetric thickness analysis @@ -41,13 +41,13 @@ For visualization instructions and outputs, see the [cc_visualization.py documen ### 2D Analysis -When using `--slice_selection middle` or a specific slice number with `--save_template`: +When using `--slice_selection middle` or a 
specific slice number with `--save_template_dir`: ```bash # Generate 2D template data (middle slice) -python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \ +python3 fastsurfer_cc.py --sd /data/subjects --sid sub001 \ --slice_selection middle \ - --save_template /data/templates/sub001 + --save_template_dir /data/templates/sub001 ``` **Benefits:** diff --git a/doc/scripts/cc_visualization.rst b/doc/scripts/cc_visualization.rst index b280a953..e9e4c136 100644 --- a/doc/scripts/cc_visualization.rst +++ b/doc/scripts/cc_visualization.rst @@ -12,7 +12,7 @@ Usage Examples 3D Visualization ~~~~~~~~~~~~~~~~ -To visualize a 3D template generated by ``fastsurfer_cc.py`` (using ``--slice_selection all --save_template ...``), +To visualize a 3D template generated by ``fastsurfer_cc.py`` (using ``--slice_selection all --save_template_dir ...``), point the script to the exported template directory: .. code-block:: bash @@ -24,7 +24,7 @@ point the script to the exported template directory: 2D Visualization ~~~~~~~~~~~~~~~~ -To visualize a 2D template (using ``--slice_selection middle --save_template ...``): +To visualize a 2D template (using ``--slice_selection middle --save_template_dir ...``): .. code-block:: bash @@ -35,9 +35,7 @@ To visualize a 2D template (using ``--slice_selection middle --save_template ... .. note:: - You can still pass ``--contours``, ``--thickness`` and - ``--measurement_points`` directly when working with standalone files, but - ``--template_dir`` is the recommended way to load the multi-slice templates + ``--template_dir`` is required to load the templates produced by ``fastsurfer_cc.py``. Outputs From f99adfd8449802e61669efffcbe096a6eb2b252e Mon Sep 17 00:00:00 2001 From: ClePol Date: Tue, 6 Jan 2026 16:46:43 +0100 Subject: [PATCH 68/68] fix 3d visualization --- CorpusCallosum/cc_visualization.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py index 714f44dd..a793ba44 100644 --- a/CorpusCallosum/cc_visualization.py +++ b/CorpusCallosum/cc_visualization.py @@ -222,15 +222,6 @@ def main( cc_mesh.plot_mesh(**plot_kwargs) cc_mesh.plot_mesh(output_path=str(output_dir / "cc_mesh.html"), **plot_kwargs) - # Here we need to load the np.linalg.inv(fsavg_vox2ras @ orig2fsavg_vox2vox) - # This is the same as orig2fsavg_ras2ras from cc_up.lta - # orig2fsavg_ras2ras = read_lta(output_dir / "mri/transforms/cc_up.lta") - # orig = nibabel.load(output_dir / "mri/orig.mgz") - # cc_mesh = cc_mesh.to_vox_coordinates(mesh_ras2vox=np.linalg.inv(orig2fsavg_ras2ras @ orig.affine)) - # If we are willing to screenshot here in fsavg space, this can be simplified to just fsavg_vox2ras - from CorpusCallosum.data.read_write import load_fsaverage_data fsavg_vox2ras, _ = load_fsaverage_data(Path(__file__).parent / "data/fsaverage_data.json") cc_mesh = cc_mesh.to_vox_coordinates(mesh_ras2vox=np.linalg.inv(fsavg_vox2ras)) logger.info(f"Writing vtk file to {output_dir / 'cc_mesh.vtk'}") cc_mesh.write_vtk(str(output_dir / "cc_mesh.vtk")) logger.info(f"Writing freesurfer surface file to {output_dir / 'cc_mesh.fssurf'}")
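
Implementation notes on the numeric fixes in patches 63-67, with minimal standalone sketches. All helper names below are hypothetical illustrations, not pipeline API.

`get_cc_volume_voxel` measures a fixed-width mid-sagittal slab from a mask spanning an odd number of sagittal slices: the central slices count fully, and the two outermost slices are weighted by the fractional width still needed. The weighting in isolation (isotropic 1 mm voxels assumed in the example):

```python
import numpy as np

def slab_volume(mask: np.ndarray, desired_width_mm: float,
                voxel_size: tuple = (1.0, 1.0, 1.0)) -> float:
    """Volume of `mask` restricted to a slab of desired_width_mm centered in x."""
    voxel_volume = float(np.prod(voxel_size))
    width_vox = mask.shape[0]
    desired_width_vox = desired_width_mm / voxel_size[0]
    # the (width_vox - 2) central slices lie fully inside the slab; the remaining
    # width is split evenly between the two outermost slices
    edge_fraction = (desired_width_vox - (width_vox - 2)) / 2
    return (edge_fraction * (mask[0].sum() + mask[-1].sum()) + mask[1:-1].sum()) * voxel_volume

mask = np.ones((7, 4, 4), dtype=bool)           # 7 sagittal slices of 16 voxels each
print(slab_volume(mask, desired_width_mm=6.0))  # 0.5 * (16 + 16) + 5 * 16 = 96.0 mm^3
```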
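Several of the geometry corrections (perimeter in `recon_cc_surf_measure`, `get_contour_edge_lengths`) come down to the same detail: a contour is a closed polygon, so the edge from the last point back to the first must be counted. A sketch of the perimeter case, assuming the (2, N) contour layout used in the shape module:

```python
import numpy as np

def closed_perimeter(contour: np.ndarray) -> float:
    """Perimeter of a closed (2, N) contour, including the last->first edge."""
    edges = np.roll(contour, -1, axis=1) - contour  # edge i runs from point i to point i+1, wrapping around
    return float(np.sum(np.linalg.norm(edges, axis=0)))

square = np.array([[0.0, 1.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0, 1.0]])
print(closed_perimeter(square))  # 4.0; np.diff alone would miss the closing edge and give 3.0
```

`get_contour_edge_lengths` applies the same `np.roll` trick to (N, 2) point arrays along axis 0.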
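The reworked `calc_subsegment_areas` treats the split contours as a nested family: sort the cumulative areas, then difference consecutive values to get the incremental pieces. The same idea in isolation, using the shoelace formula for closed polygons in place of the module's `np.trapz` integration:

```python
import numpy as np

def closed_area(contour: np.ndarray) -> float:
    """Absolute area of a closed (2, N) polygon via the shoelace formula."""
    x, y = contour
    return 0.5 * abs(np.sum(x * np.roll(y, -1) - np.roll(x, -1) * y))

def incremental_areas(nested_contours: list) -> np.ndarray:
    """Areas of the pieces between consecutive nested contours, smallest first."""
    cumulative = np.sort([closed_area(c) for c in nested_contours])
    return np.diff(cumulative, prepend=0.0)

t = np.linspace(0.0, 2.0 * np.pi, 5)[:-1]
inner = np.sqrt(2.0) * np.array([np.cos(t), np.sin(t)])  # diamond of area 4
outer = 2.0 * inner                                      # scaled copy, area 16
print(incremental_areas([outer, inner]))                 # [ 4. 12.]
```

Sorting is what makes the result independent of whether the cumulative contours arrive smallest-first or largest-first.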
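The wrap-around fix in `subsegment_midline_orthogonal` extends the point-side sign test to the edge that closes the contour. In isolation, with a splitting line given by a point and a normal:

```python
import numpy as np

def crossing_edges(contour: np.ndarray, point: np.ndarray, normal: np.ndarray) -> np.ndarray:
    """Indices i of edges (i -> i+1 mod N) of a closed (2, N) contour crossed by the line."""
    side = (contour - point[:, None]).T @ normal  # signed side of each vertex w.r.t. the line
    side_wrapped = np.append(side, side[0])       # include the last->first edge
    return np.where(side_wrapped[:-1] * side_wrapped[1:] <= 0)[0]

square = np.array([[0.0, 2.0, 2.0, 0.0],
                   [0.0, 0.0, 2.0, 2.0]])
# the vertical line x = 1 crosses the bottom edge (index 0) and the top edge (index 2)
print(crossing_edges(square, np.array([1.0, 0.0]), np.array([1.0, 0.0])))  # [0 2]
```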
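The `edge_idx`/`edge_frac` computation in the same function places split points at area-weight fractions of the midline by interpolating in index space (reasonable when the midline has been resampled to roughly equidistant points, as `midline_equi` suggests), with a clamp so a weight of exactly 1.0 stays in bounds:

```python
import numpy as np

def points_at_fractions(polyline: np.ndarray, fractions) -> np.ndarray:
    """Interpolate points at the given fractions along an (M, 2) polyline's index range."""
    idx_float = (len(polyline) - 1) * np.asarray(fractions, dtype=float)
    idx, frac = np.divmod(idx_float, 1)
    idx = idx.astype(int)
    at_end = idx >= len(polyline) - 1  # a fraction of exactly 1.0 would index out of bounds
    idx[at_end] = len(polyline) - 2
    frac[at_end] = 1.0
    return polyline[idx] + (polyline[idx + 1] - polyline[idx]) * frac[:, None]

line = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
print(points_at_fractions(line, [0.25, 0.5, 1.0]))  # [[0.5 0. ] [1.  0. ] [2.  0. ]]
```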