diff --git a/.gitignore b/.gitignore index 7c0db26..b7c01ec 100644 --- a/.gitignore +++ b/.gitignore @@ -1,10 +1,21 @@ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] -settings.json -_data/* -.DS_Store + +# Distribution / packaging *.egg *.egg-info/ uv.lock + +# IDE / editor settings +settings.json + +# macOS system files +.DS_Store + +# Data directories +_data/* + +# Logs and temporary files logs/* +temp/* diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..cc78c9e --- /dev/null +++ b/Makefile @@ -0,0 +1,19 @@ +# Makefile for ELK project + +# Default environment (production or development) +ELK_ENV ?= production + +# Main target: run the project +run: + @echo "Running in $(ELK_ENV) environment..." + @ELK_ENV=$(ELK_ENV) uv run src/main.py + +# Optional: create virtual environment using uv if needed +venv: + @echo "Ensuring virtual environment exists..." + @uv venv + +# Clean virtual environment (optional) +clean: + @echo "Removing virtual environment..." + @rm -rf venv diff --git a/README.md b/README.md index bd66fa3..7468242 100644 --- a/README.md +++ b/README.md @@ -38,15 +38,18 @@ pip install uv ## Running the App -Once `uv` is installed, navigate to the project's root directory and run: +Once the project is set up, navigate to the project root directory and use the provided Makefile to run the app: -## Running the App +```sh +# Run with the default environment (production) +make run -Once `uv` is installed, navigate to the project's root directory and run: +# Run in development environment +make run ELK_ENV=development -```sh -uv run src/main.py -uv run src/main.py +### Optional: +make venv # Creates/ensures the virtual environment exists +make clean # Clean the virtual environment ``` On the first run, this command creates a virtual environment, installs all dependencies, and starts the app. This may take a moment but only happens the first time. @@ -75,7 +78,7 @@ For help and support, please contact: ## Version -The current version of the package is **0.7.1**. +The current version of the package is **0.8.0**. ## License @@ -240,7 +243,3 @@ Ensure your pull request includes: - A clear description of what the changes do and why they are necessary. - Any relevant issue numbers. - ---- - -This `README.md` provides an overview of the ELK package, including its key features, installation instructions, supported formats, authorship, contact information, current version, and license details. For any additional information or assistance, please reach out to the provided contact emails. diff --git a/elk.command b/elk.command deleted file mode 100755 index 88f50f2..0000000 --- a/elk.command +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash -cd "$(dirname "$0")" -uv venv -uv run python src/main.py \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index d62c43f..c6b4673 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,24 +1,28 @@ [project] -name = "electrodelocalizationkit" -version = "0.3.1" +name = "electrode-localization-kit" +version = "0.8.0" description = "A tool for rapid EEG electrode localization using IR stereo cameras or MRI." 
readme = "README.md" -requires-python = ">=3.11, <3.12" -license = { file = "LICENSE" } +license = "GPL-3.0-or-later" +license-files = ["LICEN[CS]E"] +requires-python = "==3.11.*" dependencies = [ - "nibabel==5.2.1", + "vtk==9.2.6", + "PyQt6==6.7.0", "numpy==1.25.0", - "opencv_python==4.8.1.78", "pandas==2.2.2", - "PyQt6==6.7.0", - "PyQt6_sip==13.6.0", + "scipy==1.15.3", + "nibabel==5.2.1", "vedo==2023.4.6", - "vtk==9.2.6" + "PyQt6_sip==13.6.0", + "scikit_learn==1.7.1", + "scikit_image==0.25.2", + "opencv_python==4.8.1.78", ] [tool.black] line-length = 100 [tool.ruff] -line-length = 100 \ No newline at end of file +line-length = 100 diff --git a/requirements.txt b/requirements.txt index f894dfe..54faaea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,42 +8,60 @@ deprecated==1.2.18 # via vedo fonttools==4.56.0 # via matplotlib +imageio==2.37.0 + # via scikit-image +joblib==1.5.2 + # via scikit-learn kiwisolver==1.4.8 # via matplotlib +lazy-loader==0.4 + # via scikit-image matplotlib==3.10.0 # via vtk +networkx==3.5 + # via scikit-image nibabel==5.2.1 - # via electrodelocalizationkit (pyproject.toml) + # via electrode-localization-kit (pyproject.toml) numpy==1.25.0 # via - # electrodelocalizationkit (pyproject.toml) + # electrode-localization-kit (pyproject.toml) # contourpy + # imageio # matplotlib # nibabel # opencv-python # pandas + # scikit-image + # scikit-learn + # scipy + # tifffile # vedo opencv-python==4.8.1.78 - # via electrodelocalizationkit (pyproject.toml) + # via electrode-localization-kit (pyproject.toml) packaging==24.2 # via + # lazy-loader # matplotlib # nibabel + # scikit-image pandas==2.2.2 - # via electrodelocalizationkit (pyproject.toml) + # via electrode-localization-kit (pyproject.toml) pillow==11.1.0 - # via matplotlib + # via + # imageio + # matplotlib + # scikit-image pygments==2.19.1 # via vedo pyparsing==3.2.1 # via matplotlib pyqt6==6.7.0 - # via electrodelocalizationkit (pyproject.toml) + # via electrode-localization-kit (pyproject.toml) pyqt6-qt6==6.7.3 # via pyqt6 pyqt6-sip==13.6.0 # via - # electrodelocalizationkit (pyproject.toml) + # electrode-localization-kit (pyproject.toml) # pyqt6 python-dateutil==2.9.0.post0 # via @@ -51,15 +69,28 @@ python-dateutil==2.9.0.post0 # pandas pytz==2025.1 # via pandas +scikit-image==0.25.2 + # via electrode-localization-kit (pyproject.toml) +scikit-learn==1.7.1 + # via electrode-localization-kit (pyproject.toml) +scipy==1.15.3 + # via + # electrode-localization-kit (pyproject.toml) + # scikit-image + # scikit-learn six==1.17.0 # via python-dateutil +threadpoolctl==3.6.0 + # via scikit-learn +tifffile==2025.10.16 + # via scikit-image tzdata==2025.1 # via pandas vedo==2023.4.6 - # via electrodelocalizationkit (pyproject.toml) + # via electrode-localization-kit (pyproject.toml) vtk==9.2.6 # via - # electrodelocalizationkit (pyproject.toml) + # electrode-localization-kit (pyproject.toml) # vedo wrapt==1.17.2 # via deprecated diff --git a/src/config/sizes.py b/src/config/sizes.py index 015aa73..182253c 100644 --- a/src/config/sizes.py +++ b/src/config/sizes.py @@ -2,15 +2,16 @@ Size configuration for the electrodes and flagposts. 
""" + class ElectrodeSizes: - HEADSCAN_ELECTRODE_SIZE = 0.02 + HEADSCAN_ELECTRODE_SIZE = 3.5 MRI_ELECTRODE_SIZE = 0.02 LABEL_ELECTRODE_SIZE = 0.04 - + HEADSCAN_FLAGPOST_SIZE = 0.6 MRI_FLAGPOST_SIZE = 0.6 LABEL_FLAGPOST_SIZE = 0.6 - - HEADSCAN_FLAGPOST_HEIGHT = 0.05 + + HEADSCAN_FLAGPOST_HEIGHT = 5.0 MRI_FLAGPOST_HEIGHT = 0.05 - LABEL_FLAGPOST_HEIGHT = 0.05 \ No newline at end of file + LABEL_FLAGPOST_HEIGHT = 0.05 diff --git a/src/data/loader.py b/src/data/loader.py index 409abcc..859c159 100644 --- a/src/data/loader.py +++ b/src/data/loader.py @@ -6,10 +6,13 @@ import vedo as vd import nibabel as nib +from processing_models.mesh.mesh_loader import MeshLoader -def load_head_surface_mesh_from_file(filename: str) -> vd.Mesh: + +def load_head_surface_mesh_from_file(surface_file: str, texture_file: str = None) -> vd.Mesh: """Loads a head surface mesh from a file.""" - return vd.Mesh(filename) + mesh_loader = MeshLoader(surface_file, texture_file) + return mesh_loader.mesh_preprocessed.clone() def load_mri_surface_mesh_from_file(filename: str) -> vd.Mesh: diff --git a/src/data_models/cap_model.py b/src/data_models/cap_model.py index b3fe23a..4cd4805 100644 --- a/src/data_models/cap_model.py +++ b/src/data_models/cap_model.py @@ -117,15 +117,15 @@ def insert_electrode(self, electrode: Electrode, parent=QModelIndex()) -> None: electrode.coordinates, electrode.modality, include_fiducials=True ) - too_close_electrodes = [ - d[0] - for d in distances - if ( - d[1] <= ElectrodeSizes.HEADSCAN_ELECTRODE_SIZE / 2 - or d[1] <= ElectrodeSizes.MRI_ELECTRODE_SIZE / 2 - or d[1] <= ElectrodeSizes.LABEL_ELECTRODE_SIZE / 2 - ) - ] + min_distance = -1 + if electrode.modality == ModalitiesMapping.HEADSCAN: + min_distance = ElectrodeSizes.HEADSCAN_ELECTRODE_SIZE / 2 + elif electrode.modality == ModalitiesMapping.MRI: + min_distance = ElectrodeSizes.MRI_ELECTRODE_SIZE / 2 + elif electrode.modality == ModalitiesMapping.REFERENCE: + min_distance = ElectrodeSizes.LABEL_ELECTRODE_SIZE / 2 + + too_close_electrodes = [d[0] for d in distances if (d[1] <= min_distance)] if len(too_close_electrodes) > 0: return diff --git a/src/data_models/head_models.py b/src/data_models/head_models.py index 5ca1393..64e7eee 100644 --- a/src/data_models/head_models.py +++ b/src/data_models/head_models.py @@ -37,16 +37,16 @@ def __init__(self, surface_file: str, texture_file: str | None = None): self.texture_file = texture_file self.mesh = None - self.mesh = load_head_surface_mesh_from_file(surface_file) + self.mesh = load_head_surface_mesh_from_file(surface_file, texture_file) self.modality = ModalitiesMapping.HEADSCAN self.fiducials = [] - self.normalization_scale = 1 + self.normalization_scale = 1000 self._registered = False - self.normalize() + # self.normalize() self.apply_texture() @@ -56,9 +56,12 @@ def normalize(self): def rescale_to_original_size(self): self.normalization_scale = rescale_to_original_size(self.mesh, self.normalization_scale) # type: ignore - def apply_texture(self): + def apply_texture(self, texture_file: str | None = None): if self.texture_file is not None: self.mesh = self.mesh.texture(self.texture_file) # type: ignore + elif texture_file is not None: + self.texture_file = texture_file + self.mesh = self.mesh.texture(texture_file) def register_mesh(self, surface_registrator: BaseSurfaceRegistrator) -> np.ndarray: transform_matrix = surface_registrator.register() # type: ignore diff --git a/src/fileio/locations.py b/src/fileio/locations.py index 53db661..1887d9a 100644 --- a/src/fileio/locations.py +++ 
b/src/fileio/locations.py @@ -28,7 +28,7 @@ def load_locations( # model.remove_electrode_by_id(electrode_id) if ENV == "development": - files["locations"] = "sample_data/measured_electrodes.ced" + files["locations"] = "sample_data/electrode_locations.ced" else: file_path, _ = QFileDialog.getOpenFileName( None, diff --git a/src/fileio/scan.py b/src/fileio/scan.py index b3884ee..972286f 100644 --- a/src/fileio/scan.py +++ b/src/fileio/scan.py @@ -6,6 +6,7 @@ from data_models.cap_model import CapModel from data_models.head_models import HeadScan import os +from ui.pyloc_main_window import Ui_ELK ENV = os.getenv("ELK_ENV", "production") @@ -17,6 +18,7 @@ def load_surface( headmodels: dict, frames: list[tuple[str, QFrame]], model: CapModel, + ui: Ui_ELK = None, ): if ENV == "development": files["scan"] = "sample_data/model_mesh.obj" @@ -34,6 +36,7 @@ def load_surface( headmodels["scan"], frame, model, + ui, ) @@ -44,6 +47,7 @@ def load_texture( frames: list[tuple[str, QFrame]], model: CapModel, electrode_detector: BaseElectrodeDetector | None, + ui: Ui_ELK = None, ): if ENV == "development": files["texture"] = "sample_data/model_texture.jpg" @@ -60,13 +64,17 @@ def load_texture( if electrode_detector: electrode_detector.apply_texture(files["texture"]) - headmodels["scan"] = HeadScan(files["scan"], files["texture"]) + if headmodels.get("scan") is not None: + headmodels["scan"].apply_texture(files["texture"]) + else: + headmodels["scan"] = HeadScan(files["scan"], files["texture"]) for label, frame in frames: views[label] = create_surface_view( headmodels["scan"], frame, model, + ui, ) @@ -74,6 +82,7 @@ def create_surface_view( head_scan: HeadScan, frame: QFrame, model: CapModel, + ui: Ui_ELK = None, ) -> SurfaceView | None: config = { "sphere_size": ElectrodeSizes.HEADSCAN_ELECTRODE_SIZE, @@ -88,6 +97,7 @@ def create_surface_view( [head_scan.modality], config, model, + ui=ui, ) return surface_view diff --git a/src/main.py b/src/main.py index bc772a4..df3561b 100644 --- a/src/main.py +++ b/src/main.py @@ -30,6 +30,7 @@ from ui.callbacks.refresh import refresh_views_on_resize from ui.callbacks.connect.connect_fileio import connect_fileio_buttons from ui.callbacks.connect.connect_texture import connect_texture_buttons +from ui.callbacks.connect.connect_detection import connect_detection_buttons from ui.callbacks.connect.connect_configuration_boxes import connect_configuration_boxes from ui.callbacks.connect.connect_sliders import connect_alpha_sliders from ui.callbacks.connect.connect_model import connect_model @@ -81,10 +82,10 @@ def __init__(self, parent=None): connect_model(self) connect_fileio_buttons(self) connect_texture_buttons(self) + connect_detection_buttons(self) connect_scan_mri_alignment_buttons(self) connect_display_secondary_mesh_checkbox(self) connect_configuration_boxes(self) - connect_texture_buttons(self) connect_alpha_sliders(self) connect_tab_changed(self) connect_splitter_moved(self) @@ -138,6 +139,10 @@ def set_data_containers(self): "scan": None, "mri": None, } + self.loaders = { + "mesh": None, + "view": None, + } self.images = {"dog": None, "hough": None} diff --git a/src/processing_handlers/detection_processing.py b/src/processing_handlers/detection_processing.py new file mode 100644 index 0000000..25b069c --- /dev/null +++ b/src/processing_handlers/detection_processing.py @@ -0,0 +1,172 @@ +import os +import logging +import numpy as np +import pandas as pd + +from ui.pyloc_main_window import Ui_ELK +from data_models.cap_model import CapModel +from data_models.electrode 
import Electrode +from data_models.head_models import HeadScan +from config.logger_config import setup_logger +from processing_models.mesh import MeshLoader +from view.interactive_surface_view import InteractiveSurfaceView +from processing_models.electrode import ( + DetectionMethod, + ElectrodeMapper, + ProcessingParams, + ViewType, + ViewLoader, +) + + +setup_logger(logging.INFO) +logger = logging.getLogger(__name__) + +TEMP_DIR = "./temp" +VIEWS_PATH = "./temp" +FIDUCIALS_PATH = "./temp/fiducials.csv" +ELECTRODES_PATH = "./temp/electrodes.csv" + +FIDUCIALS_REQUIRED = {"NAS", "INI", "LPA", "RPA", "VTX"} +VIEW_TYPES = [ + ViewType.FRONT, + ViewType.BACK, + ViewType.TOP, + ViewType.RIGHT, + ViewType.LEFT, + ViewType.FRONT_TOP, + ViewType.BACK_TOP, + ViewType.FRONT_RIGHT, + ViewType.FRONT_LEFT, + ViewType.BACK_RIGHT, + ViewType.BACK_LEFT, + ViewType.TOP_RIGHT, + ViewType.TOP_LEFT, +] +DETECTION_METHODS = [ + DetectionMethod.TRADITIONAL, + DetectionMethod.FRST, +] + + +def process_mesh( + view: InteractiveSurfaceView, + headmodel: HeadScan, + model: CapModel, + loaders: dict[str, MeshLoader | ViewLoader], + ui: Ui_ELK = None, +): + fiducials = model.get_fiducials([headmodel.modality]) + + # Save fiducials to file + os.makedirs(TEMP_DIR, exist_ok=True) + with open(FIDUCIALS_PATH, "w", encoding="utf-8") as f: + for fid in fiducials: + if ( + fid is not None + and fid.label is not None + and fid.label in FIDUCIALS_REQUIRED + and fid.coordinates is not None + ): + f.write( + f"{fid.label},{fid.coordinates[0]},{fid.coordinates[1]},{fid.coordinates[2]}\n" + ) + logger.info(f"Fiducials saved to {FIDUCIALS_PATH}") + + # Process mesh + mesh_loader = MeshLoader(headmodel.surface_file, headmodel.texture_file, FIDUCIALS_PATH) + loaders["mesh"] = mesh_loader + mesh_loader.clean_data(x_margin=0.5, y_top_margin=0.25, y_bottom_margin=1.0, z_margin=0.25) + mesh_loader.extract_cap_data(margin=0.0) + mesh_loader.capture_data(VIEWS_PATH) + + # Transform fiducials back + transformed_fiducials = pd.DataFrame( + transform_points( + mesh_loader.fiducials.to_numpy(), + mesh_loader.aligner.source_origin, + mesh_loader.aligner.rotation_matrix, + inverse=True, + ), + index=mesh_loader.fiducials.index, + columns=mesh_loader.fiducials.columns, + ) + + # Update fiducial locations + for i in range(len(fiducials)): + if fiducials[i].label is not None and fiducials[i].label in transformed_fiducials.index: + point = transformed_fiducials.loc[fiducials[i].label].values + fiducials[i].coordinates = point + + # Rerender + ui and ui.detect_button.setEnabled(True) + view.render_electrodes() + + +def detect_electrodes( + view: InteractiveSurfaceView, + headmodel: HeadScan, + model: CapModel, + loaders: dict[str, MeshLoader | ViewLoader], +): + mesh_loader = loaders["mesh"] + view_loader = ViewLoader(VIEWS_PATH, ProcessingParams()) + loaders["view"] = view_loader + + # Detect markers and electrodes (in parallel) + view_loader.detect_markers_and_electrodes(VIEW_TYPES, DETECTION_METHODS) + + # Label markers + view_loader.label_markers(VIEW_TYPES) + + # Map electrodes + mapper = ElectrodeMapper( + mesh_loader.mesh_cleaned, + mesh_loader.mesh_extracted, + mesh_loader.fiducials, + mesh_loader.aligner.origin, + ) + mapper.map_electrodes_to_3d(view_loader.detected, view_loader.metadata) + + # Transform electrodes back + transformed_electrodes = mapper.electrodes.copy() + transformed_electrodes[:, (2, 3, 4)] = transform_points( + transformed_electrodes[:, (2, 3, 4)], + mesh_loader.aligner.source_origin, + 
mesh_loader.aligner.rotation_matrix, + inverse=True, + ) + mapper.save(ELECTRODES_PATH, transformed_electrodes) + electrodes = pd.DataFrame( + transformed_electrodes[:, (2, 3, 4, 5)], + columns=["x", "y", "z", "label"], + index=[method.value for method in mapper.electrodes[:, 1]], + ) + + # Project electrodes on map + for _, row in electrodes.iterrows(): + point = (row["x"], row["y"], row["z"]) + label = row["label"] if pd.notna(row["label"]) else "None" + labeled = pd.notna(row["label"]) + model.insert_electrode( + Electrode(point, modality=headmodel.modality, label=label, labeled=labeled) + ) + view.render_electrodes() + + +def transform_points(points, origin, rotation_matrix, inverse=False): + """ + Transform 3D points between coordinate systems. + X' = (X - O) @ R^T -> X = X' @ R + O + Rotation matrix: R^T @ R = R @ R^T = I -> R^-1 = R^T + """ + points = np.asarray(points.copy()) + origin = np.asarray(origin.copy()) + R = np.asarray(rotation_matrix.copy()) + + if not inverse: + transformed = (points - origin) @ R.T + else: + transformed = points @ R + origin + + return transformed diff --git a/src/processing_models/electrode/__init__.py b/src/processing_models/electrode/__init__.py new file mode 100644 index 0000000..582a064 --- /dev/null +++ b/src/processing_models/electrode/__init__.py @@ -0,0 +1,47 @@ +from .basic_electrode_detector import BasicElectrodeDetector +from .color_space import ColorSpace +from .detection_method import DetectionMethod +from .electrode_detector import ElectrodeDetector +from .electrode_mapper import ElectrodeMapper +from .electrode_merger import ElectrodeMerger +from .frst import FRST +from .marker_detector import MarkerDetector +from .marker_labeler import MarkerLabeler +from .view_loader import ViewLoader +from .view_type import ViewType + +from .params import ( + ProcessingParams, +) + +from .util import ( + BackgroundMaskUtil, + ColorEnhancementUtil, + ColorQuantizationUtil, + IlluminationCorrectionUtil, + NoiseReductionUtil, + SharpeningUtil, +) + +__all__ = [ + "BasicElectrodeDetector", + "ColorSpace", + "DetectionMethod", + "ElectrodeDetector", + "FRST", + "ElectrodeMapper", + "ElectrodeMerger", + "MarkerDetector", + "MarkerLabeler", + "ViewLoader", + "ViewType", + # Params + "ProcessingParams", + # Utils + "BackgroundMaskUtil", + "ColorEnhancementUtil", + "ColorQuantizationUtil", + "IlluminationCorrectionUtil", + "NoiseReductionUtil", + "SharpeningUtil", +] diff --git a/src/processing_models/electrode/basic_electrode_detector.py b/src/processing_models/electrode/basic_electrode_detector.py new file mode 100644 index 0000000..4ceb53d --- /dev/null +++ b/src/processing_models/electrode/basic_electrode_detector.py @@ -0,0 +1,363 @@ +import cv2 +import logging +import numpy as np + +from typing import List, Optional, Tuple + +from .frst import FRST +from .view_type import ViewType +from .color_space import ColorSpace +from .detection_method import DetectionMethod +from .util import ColorEnhancementUtil, NoiseReductionUtil + + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + +class BasicElectrodeDetector: + + def __init__( + self, + min_area: Optional[int] = None, + min_distance: Optional[int] = None, + min_radius: Optional[int] = None, + max_radius: Optional[int] = None, + ) -> None: + """ + Initialize basic electrode detector. 
+ """ + self.min_area = min_area + self.min_distance = min_distance + self.min_radius = min_radius + self.max_radius = max_radius + + def _preprocess_image(self, image: np.ndarray) -> np.ndarray: + """ + Reduce noise, and normalize image to grayscale. + """ + # Reduce the green color markers + image = ColorEnhancementUtil.enhance_green_in_hsv(image, 60.0, 25.0, -2.5, 1.5) + + hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32) + h, s, v = cv2.split(hsv) + + # Darken very dark pixels further to boost contrast + dark_mask = v < 128 + v[dark_mask] = np.clip(v[dark_mask] * 0.5, 0, 255) + hsv_boosted = cv2.merge([h, s, v]) + rgb_boosted = cv2.cvtColor(hsv_boosted.astype(np.uint8), cv2.COLOR_HSV2RGB) + + # Denoise and convert to grayscale + denoised = NoiseReductionUtil.nlm_denoising(rgb_boosted, 25.0, ColorSpace.LAB) + grayscale = cv2.cvtColor(denoised, cv2.COLOR_RGB2GRAY) + + # Normalize intensities to 0–255 + normalized = cv2.normalize(grayscale, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8) + + return normalized + + def _apply_morphological_operations(self, binary_mask: np.ndarray) -> np.ndarray: + """ + Clean up binary mask with opening and closing operations. + """ + kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10)) + kernel_large = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15)) + + # Remove small noise with opening + opened = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel_small) + + # Fill small holes with closing + closed = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, kernel_large) + + # Make structures a bit smaller again + reopened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel_small) + + # Refine shapes with erosion-dilation + eroded = cv2.erode(reopened, kernel_small, iterations=1) + refined = cv2.dilate(eroded, kernel_large, iterations=1) + + return refined + + def _filter_contours( + self, + morph_mask: np.ndarray, + view_type: ViewType, + use_frst: bool = False, + ) -> np.ndarray: + """ + Keep only contours that look like circular electrodes. + """ + contours, _ = cv2.findContours(morph_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + filtered = np.zeros_like(morph_mask) + + for contour in contours: + area = cv2.contourArea(contour) + min_area = self.min_area or view_type.electrode_cfg.min_area or 500 + if area < min_area: + continue + + perimeter = cv2.arcLength(contour, True) + if perimeter == 0: + continue + + circularity = 4 * np.pi * area / (perimeter * perimeter) + min_circularity = 0.5 if use_frst else 0.75 + if circularity <= min_circularity: + continue + + cv2.drawContours(filtered, [contour], -1, 255, -1) + + return filtered + + def get_electrode_mask( + self, + image: np.ndarray, + foreground_mask: np.ndarray, + view_type: ViewType, + methods: List[DetectionMethod], + ) -> np.ndarray: + """ + Detect electrode regions using one or more detection methods and combine their results into a single mask. 
+ """ + if not methods: + # Return full mask if no methods are provided + return np.ones(image.shape[:2], dtype=np.uint8) * 255 + + # Start with an empty mask + electrode_mask = np.zeros(image.shape[:2], dtype=np.uint8) + + if DetectionMethod.TRADITIONAL in methods: + grayscale = self._preprocess_image(image) + _, binary = cv2.threshold(grayscale, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) + binary = cv2.bitwise_and(binary, binary, mask=foreground_mask) + morph_mask = self._apply_morphological_operations(binary) + traditional_mask = self._filter_contours(morph_mask, view_type, use_frst=False) + electrode_mask = cv2.bitwise_or(electrode_mask, traditional_mask) + + if DetectionMethod.FRST in methods: + grayscale = self._preprocess_image(image) + tophat_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) + tophat = cv2.morphologyEx(grayscale, cv2.MORPH_TOPHAT, tophat_kernel) + frst_input = cv2.subtract(grayscale, tophat) + + radii_range = list(range(15, 50, 5)) + frst = FRST() + frst_response = frst.run(frst_input, radii_range, alpha=1.0, beta=0.1) + + eroded_mask = cv2.erode( + foreground_mask, + cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10)), + iterations=1, + ) + frst_response = cv2.bitwise_and( + frst_response, frst_response, mask=eroded_mask.astype(np.uint8) + ) + + frst_normalized = cv2.normalize(frst_response, None, 0, 255, cv2.NORM_MINMAX).astype( + np.uint8 + ) + _, frst_binary = cv2.threshold( + frst_normalized, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU + ) + frst_binary = cv2.bitwise_and(frst_binary, frst_binary, mask=foreground_mask) + + morph_mask = self._apply_morphological_operations(frst_binary) + frst_mask = self._filter_contours(morph_mask, view_type, use_frst=True) + electrode_mask = cv2.bitwise_or(electrode_mask, frst_mask) + + return electrode_mask + + def _apply_dog_and_hough( + self, + original_image: np.ndarray, + grayscale: np.ndarray, + filtered_mask: np.ndarray, + view_type: ViewType, + use_frst: bool = False, + ) -> Optional[np.ndarray]: + """ + Detect circles using Difference of Gaussians + Hough Circle Transform. 
+ """ + # Apply DoG filtering + masked = cv2.bitwise_and(grayscale, grayscale, mask=filtered_mask) + blur_small = cv2.GaussianBlur(masked.astype(np.float32), (0, 0), 2.5) + blur_large = cv2.GaussianBlur(masked.astype(np.float32), (0, 0), 5.0) + dog = blur_small - blur_large + + dog_normalized = cv2.normalize(dog, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8) + dog_filtered = cv2.medianBlur(dog_normalized, 5) + + # Hough Circle parameters + param2 = 25 if use_frst else 30 + min_dist = self.min_distance or view_type.electrode_cfg.min_distance or 50 + min_r = self.min_radius or view_type.electrode_cfg.min_radius or 10 + max_r = self.max_radius or view_type.electrode_cfg.max_radius or 25 + + circles = cv2.HoughCircles( + dog_filtered, + cv2.HOUGH_GRADIENT, + dp=1.0, + minDist=min_dist, + param1=50, + param2=param2, + minRadius=min_r, + maxRadius=max_r, + ) + + # Validate detected circles + if circles is not None: + circles = np.uint16(np.around(circles)) + valid_circles = [] + + gray_original = cv2.cvtColor(original_image, cv2.COLOR_RGB2GRAY) + + for x, y, r in circles[0, :]: + if y >= filtered_mask.shape[0] or x >= filtered_mask.shape[1]: + continue + + if filtered_mask[y, x] > 0 and gray_original[y, x] < 200: + valid_circles.append((x, y, r)) + + circles = np.array([valid_circles], dtype=np.uint16) if valid_circles else None + + return circles + + def detect_traditional( + self, + image: np.ndarray, + foreground_mask: np.ndarray, + view_type: ViewType, + ) -> Optional[np.ndarray]: + """ + Detect electrodes using thresholding + morphology + Hough transform. + """ + grayscale = self._preprocess_image(image) + + _, binary = cv2.threshold(grayscale, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) + binary = cv2.bitwise_and(binary, binary, mask=foreground_mask) + + morph_mask = self._apply_morphological_operations(binary) + electrode_mask = self._filter_contours(morph_mask, view_type, use_frst=False) + + return self._apply_dog_and_hough( + image, grayscale, electrode_mask, view_type, use_frst=False + ) + + def detect_frst( + self, + image: np.ndarray, + foreground_mask: np.ndarray, + view_type: ViewType, + ) -> Optional[np.ndarray]: + """ + Detect electrodes using FRST (Fast Radial Symmetry Transform). 
+ """ + grayscale = self._preprocess_image(image) + + # Enhance circular structures + tophat_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) + tophat = cv2.morphologyEx(grayscale, cv2.MORPH_TOPHAT, tophat_kernel) + frst_input = cv2.subtract(grayscale, tophat) + + radii_range = list(range(15, 50, 5)) + frst = FRST() + frst_response = frst.run(frst_input, radii_range, alpha=1.0, beta=0.1) + + # Apply foreground mask + eroded_mask = cv2.erode( + foreground_mask, + cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10)), + iterations=1, + ) + frst_response = cv2.bitwise_and( + frst_response, frst_response, mask=eroded_mask.astype(np.uint8) + ) + + frst_normalized = cv2.normalize(frst_response, None, 0, 255, cv2.NORM_MINMAX).astype( + np.uint8 + ) + _, frst_binary = cv2.threshold(frst_normalized, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + frst_binary = cv2.bitwise_and(frst_binary, frst_binary, mask=foreground_mask) + + morph_mask = self._apply_morphological_operations(frst_binary) + electrode_mask = self._filter_contours(morph_mask, view_type, use_frst=True) + + return self._apply_dog_and_hough(image, grayscale, electrode_mask, view_type, use_frst=True) + + def _combine_results( + self, + view_type: ViewType, + circles_traditional: Optional[np.ndarray], + circles_frst: Optional[np.ndarray], + ) -> Optional[np.ndarray]: + """ + Merge circles from both methods, removing overlaps. + """ + combined = [] + + if circles_traditional is not None: + combined.extend( + (float(x), float(y), float(r), "traditional") for x, y, r in circles_traditional[0] + ) + if circles_frst is not None: + combined.extend((float(x), float(y), float(r), "frst") for x, y, r in circles_frst[0]) + + if not combined: + return None + + if len({m for _, _, _, m in combined}) == 1: + return np.array([[(x, y, r) for x, y, r, _ in combined]], dtype=np.uint16) + + # Remove overlaps, keep larger circles + unique = [] + for x1, y1, r1, method1 in sorted(combined, key=lambda c: c[2], reverse=True): + overlap = False + for i, (x2, y2, r2, method2) in enumerate(unique): + dist = np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2) + min_dist = self.min_distance or view_type.electrode_cfg.min_distance or 0 + if dist < (r1 + r2) or dist < min_dist: + if r1 > r2: + unique[i] = (x1, y1, r1, method1) + overlap = True + break + if not overlap: + unique.append((x1, y1, r1, method1)) + + return np.array([[(x, y, r) for x, y, r, _ in unique]], dtype=np.uint16) if unique else None + + @staticmethod + def draw_circles(image: np.ndarray, circles: np.ndarray) -> np.ndarray: + """ + Draw detected circles and their centers. + """ + output = image.copy() + if circles is not None: + for x, y, r in np.uint16(np.around(circles))[0, :]: + cv2.circle(output, (x, y), r, (0, 255, 0), 3) + cv2.circle(output, (x, y), 2, (255, 0, 0), 5) + return output + + def detect( + self, + image: np.ndarray, + foreground_mask: np.ndarray, + view_type: ViewType, + methods: List[DetectionMethod], + ) -> Tuple[Optional[np.ndarray], np.ndarray]: + """ + Detect electrodes using selected methods and combine results. 
+ """ + circles_traditional, circles_frst = None, None + + if DetectionMethod.TRADITIONAL in methods: + logging.info("Running traditional electrode detection...") + circles_traditional = self.detect_traditional(image, foreground_mask, view_type) + + if DetectionMethod.FRST in methods: + logging.info("Running FRST electrode detection...") + circles_frst = self.detect_frst(image, foreground_mask, view_type) + + circles = self._combine_results(view_type, circles_traditional, circles_frst) + output_image = self.draw_circles(image, circles) + + return circles, output_image diff --git a/src/processing_models/electrode/color_space.py b/src/processing_models/electrode/color_space.py new file mode 100644 index 0000000..e649613 --- /dev/null +++ b/src/processing_models/electrode/color_space.py @@ -0,0 +1,58 @@ +import cv2 +import numpy as np + +from enum import Enum +from typing import List, Dict + + +class ColorSpace(Enum): + """ + Possible color spaces for electrode detection. + """ + + RGB = "RGB" + HSV = "HSV" + LAB = "LAB" + YUV = "YUV" + GRAY = "GRAY" + + @staticmethod + def transform_color_spaces( + image: np.ndarray, + color_spaces: List["ColorSpace"], + ) -> Dict["ColorSpace", np.ndarray]: + """ + Transform an image to multiple color spaces with robust error handling. + """ + # Color space conversion mapping + conversion_map = { + ColorSpace.HSV: (cv2.COLOR_RGB2HSV, 3), + ColorSpace.LAB: (cv2.COLOR_RGB2LAB, 3), + ColorSpace.YUV: (cv2.COLOR_RGB2YUV, 3), + ColorSpace.GRAY: (cv2.COLOR_RGB2GRAY, 1), + } + + transformed: Dict[ColorSpace, np.ndarray] = {} + for color_space in color_spaces: + if color_space == ColorSpace.RGB: + # Keep the original image + transformed[color_space] = image + else: + try: + conv_code, num_channels = conversion_map[color_space] + converted = cv2.cvtColor(image, conv_code) + + # Check channels if needed + if num_channels == 1 and converted.ndim != 2: + raise ValueError(f"Conversion to {color_space} failed to produce grayscale") + if num_channels == 3 and (converted.ndim != 3 or converted.shape[2] != 3): + raise ValueError( + f"Conversion to {color_space} failed to produce 3 channels" + ) + + transformed[color_space] = converted + + except KeyError: + raise ValueError(f"Unsupported color space: {color_space}") + + return transformed diff --git a/src/processing_models/electrode/detection_method.py b/src/processing_models/electrode/detection_method.py new file mode 100644 index 0000000..d3cfad0 --- /dev/null +++ b/src/processing_models/electrode/detection_method.py @@ -0,0 +1,14 @@ +from enum import Enum + + +class DetectionMethod(Enum): + """ + Possible cap electrode/marker detection method implementations. 
+    """
+
+    TRADITIONAL = "TRADITIONAL"
+    FRST = "FRST"
+
+    MARKER = "MARKER"
+    ELECTRODE_BASIC = "ELECTRODE_BASIC"
+    ELECTRODE = "ELECTRODE"
diff --git a/src/processing_models/electrode/electrode_detector.py b/src/processing_models/electrode/electrode_detector.py
new file mode 100644
index 0000000..1d3e922
--- /dev/null
+++ b/src/processing_models/electrode/electrode_detector.py
@@ -0,0 +1,524 @@
+import cv2
+import logging
+import numpy as np
+
+from sklearn.mixture import GaussianMixture
+from sklearn.preprocessing import StandardScaler
+from typing import Dict, List, Optional, Tuple, Union
+from skimage.segmentation import slic, mark_boundaries
+
+from .frst import FRST
+from .view_type import ViewType
+from .color_space import ColorSpace
+from .detection_method import DetectionMethod
+from .basic_electrode_detector import BasicElectrodeDetector
+from .util import ColorEnhancementUtil, ColorQuantizationUtil, NoiseReductionUtil
+
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+
+
+class ElectrodeDetector:
+
+    def __init__(
+        self,
+        min_distance: Optional[int] = None,
+        n_segments: Optional[int] = None,
+    ):
+        """
+        Initialize the superpixel segmentation electrode detector.
+        """
+        self.min_distance = min_distance
+        self.n_segments = n_segments
+
+    def _preprocess_image(self, image: np.ndarray, foreground_mask: np.ndarray) -> np.ndarray:
+        """
+        Preprocess the input image for electrode detection.
+        """
+        logging.info("Preprocessing image for electrode detection...")
+
+        # Reduce green color markers
+        image = ColorEnhancementUtil.enhance_green_in_hsv(image, 60.0, 25.0, -2.5, 1.5)
+
+        # Boost contrast in dark regions
+        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32)
+        h, s, v = cv2.split(hsv)
+        dark_mask = v < 128
+        v[dark_mask] = np.clip(v[dark_mask] * 0.5, 0, 255)
+        hsv_boosted = cv2.merge([h, s, v])
+        rgb_boosted = cv2.cvtColor(hsv_boosted.astype(np.uint8), cv2.COLOR_HSV2RGB)
+
+        # Color quantization in LAB + HSV spaces
+        processed = ColorQuantizationUtil.lab_hsv_quantization(image, 8, foreground_mask)
+
+        # Non-local means denoising in LAB color space
+        denoised = NoiseReductionUtil.nlm_denoising(processed, 50.0, ColorSpace.LAB)
+
+        # Normalize intensities
+        normalized = cv2.normalize(denoised, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
+
+        return normalized
+
+    def _apply_slic_segmentation(
+        self, image: np.ndarray, foreground_mask: np.ndarray, view_type: ViewType
+    ) -> np.ndarray:
+        """
+        Apply SLIC superpixel segmentation to the image.
+ """ + n_segments = self.n_segments or view_type.electrode_cfg.n_segments or 750 + logging.info(f"Applying SLIC segmentation with {n_segments} segments...") + + # Apply SLIC superpixel segmentation + segments = slic( + image=image, + n_segments=n_segments, + compactness=25.0, # Balance between color similarity and spatial proximity + max_num_iter=25, # Maximum number of k-means iterations + sigma=1.0, # Gaussian smoothing kernel width + convert2lab=True, # Convert to LAB color space + enforce_connectivity=True, # Ensure connected superpixels + start_label=1, # Start labeling from 1 (0 reserved for background) + mask=foreground_mask > 0, # Only segment foreground regions + channel_axis=-1, # Color channels are in the last dimension + ) + + unique_segments = len(np.unique(segments)) + logging.info(f"Generated {unique_segments} superpixels") + + return segments + + def _find_background_neighbors(self, segments: np.ndarray) -> np.ndarray: + """ + Find superpixels that are neighbors of the background (segment 0). + """ + background_positions = np.argwhere(segments == 0) + neighbors_of_background = set() + + height, width = segments.shape + + for row, col in background_positions: + # Check 4-connected neighbors + neighbor_coords = [ + (row - 1, col), + (row + 1, col), + (row, col - 1), + (row, col + 1), + ] + + for r, c in neighbor_coords: + if 0 <= r < height and 0 <= c < width: + neighbor_id = segments[r, c] + if neighbor_id != 0: # Not background + neighbors_of_background.add(neighbor_id) + + return neighbors_of_background + + def _extract_geometric_features(self, segment_mask: np.ndarray) -> Dict[str, Union[float, int]]: + """ + Extract geometric features from a segment mask. + """ + # Find contours + contours, _ = cv2.findContours( + segment_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) + + if len(contours) == 0: + # Default values for empty contours + return { + "area": np.sum(segment_mask), + "perimeter": 0, + "circularity": 0, + "aspect_ratio": 0, + "eccentricity": 0, + "solidity": 0, + "centroid_x": 0, + "centroid_y": 0, + } + + # Use the largest contour + largest_contour = max(contours, key=cv2.contourArea) + + # Basic geometric properties + area = cv2.contourArea(largest_contour) + perimeter = cv2.arcLength(largest_contour, closed=True) + + # Circularity (4π*area / perimeter²) + circularity = 4 * np.pi * area / (perimeter**2) if perimeter > 0 else 0 + + # Bounding rectangle and aspect ratio + x, y, w, h = cv2.boundingRect(largest_contour) + aspect_ratio = w / h if h > 0 else 0 + centroid_x, centroid_y = x + w / 2, y + h / 2 + + # Eccentricity from fitted ellipse + eccentricity = 1.0 # Default value + if len(largest_contour) >= 5: # Need at least 5 points for ellipse fitting + try: + (center, axes, orientation) = cv2.fitEllipse(largest_contour) + major_axis = max(axes) + minor_axis = min(axes) + if major_axis > 0: + eccentricity = np.sqrt(1 - (minor_axis**2 / major_axis**2)) + except cv2.error: + eccentricity = 1.0 + + # Solidity (area / convex hull area) + hull = cv2.convexHull(largest_contour) + hull_area = cv2.contourArea(hull) + solidity = area / hull_area if hull_area > 0 else 0 + + return { + "area": area, + "perimeter": perimeter, + "circularity": circularity, + "eccentricity": eccentricity, + "solidity": solidity, + "aspect_ratio": aspect_ratio, + "centroid_x": centroid_x, + "centroid_y": centroid_y, + } + + def _compute_local_contrast( + self, + gray_image: np.ndarray, + segment_mask: np.ndarray, + dilate_size: Optional[int] = 15, + ) -> float: + """ + 
Compute local contrast of a superpixel relative to its surroundings. + """ + # Calculate mean intensity inside the superpixel + mean_inside = cv2.mean(gray_image, mask=segment_mask)[0] + + # Create dilated mask to define surrounding region + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_size, dilate_size)) + dilated_mask = cv2.dilate(segment_mask, kernel) + + # Get surrounding ring by subtracting original mask from dilated mask + surrounding_mask = cv2.subtract(dilated_mask, segment_mask) + + # Calculate mean intensity of surrounding region + mean_surrounding = cv2.mean(gray_image, mask=surrounding_mask)[0] + + # Return contrast as difference between inside and surrounding intensities + return mean_inside - mean_surrounding + + def _extract_superpixel_features( + self, image: np.ndarray, segments: np.ndarray, foreground_mask: np.ndarray + ) -> np.ndarray: + """ + Extract features from each superpixel for electrode classification. + """ + logging.info("Extracting superpixel features...") + + features_list = [] + segment_ids = np.unique(segments) + + # Convert to LAB color space for color features + lab_image = cv2.cvtColor(image, cv2.COLOR_RGB2LAB) + l_channel, a_channel, b_channel = cv2.split(lab_image) + + # Compute FRST (Fast Radial Symmetry Transform) response + frst = FRST() + frst_response = frst.run( + lab_image[:, :, 0], radii=list(range(1, 25, 2)), alpha=1.0, beta=0.01 + ) + + # Find segments neighboring background (segment 0) + background_neighbors = self._find_background_neighbors(segments) + + # Convert to grayscale for contrast computation + gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + + # Process each superpixel + for segment_id in segment_ids: + # Skip background and edge segments + if segment_id == 0 or segment_id in background_neighbors: + continue + + segment_mask = (segments == segment_id).astype(np.uint8) + + # Filter out small segments or those mostly in background + segment_size = np.sum(segment_mask) + foreground_overlap = np.sum(segment_mask & foreground_mask) + + if segment_size < 5 or foreground_overlap < segment_size * 0.5: + continue + + # Extract LAB color features + l_mean = np.mean(l_channel[segment_mask > 0]) + l_std = np.std(l_channel[segment_mask > 0]) + a_mean = np.mean(a_channel[segment_mask > 0]) + b_mean = np.mean(b_channel[segment_mask > 0]) + + # Extract geometric features from contours + geometric_features = self._extract_geometric_features(segment_mask) + + # Compute local contrast + contrast = self._compute_local_contrast(gray_image, segment_mask, dilate_size=15) + + # Compute mean FRST response + frst_mean = cv2.mean(frst_response, mask=segment_mask)[0] + + # Combine all features + features = [ + segment_id, # 0: Segment ID + l_mean, # 1: L channel mean + l_std, # 2: L channel standard deviation + a_mean, # 3: A channel mean + b_mean, # 4: B channel mean + geometric_features["area"], # 5: Area + geometric_features["perimeter"], # 6: Perimeter + geometric_features["circularity"], # 7: Circularity + geometric_features["eccentricity"], # 8: Eccentricity + geometric_features["solidity"], # 9: Solidity + geometric_features["aspect_ratio"], # 10: Aspect ratio + contrast, # 11: Local contrast + frst_mean, # 12: FRST response mean + geometric_features["centroid_x"], # 13: Centroid X + geometric_features["centroid_y"], # 14: Centroid Y + None, # 15: Confidence (computed later) + ] + + features_list.append(features) + + logging.info(f"Extracted features from {len(features_list)} superpixels") + return np.array(features_list) + + 
def _cluster_superpixels(self, features: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Cluster superpixels using Gaussian Mixture Model to identify electrode candidates. + """ + logging.info("Clustering superpixels with GMM...") + + # Extract relevant features for clustering (contrast and FRST response) + feature_indices = [11, 12] # contrast, frst_mean + clustering_features = features[:, feature_indices] + + # Apply feature weighting + feature_weights = np.array([2.0, 1.0]) # Higher weight for contrast + weighted_features = clustering_features * feature_weights + + # Standardize features + scaler = StandardScaler() + normalized_features = scaler.fit_transform(weighted_features) + + # Fit Gaussian Mixture Model + gmm = GaussianMixture( + n_components=2, # Two clusters: electrodes vs non-electrodes + covariance_type="full", # Full covariance matrices + max_iter=250, # Maximum iterations for convergence + random_state=42, # For reproducibility + ) + + cluster_labels = gmm.fit_predict(normalized_features) + + # Identify electrode cluster (assume it's the smaller cluster) + unique_labels, label_counts = np.unique(cluster_labels, return_counts=True) + electrode_cluster_id = unique_labels[np.argmin(label_counts)] + + # Extract features of electrode cluster + electrode_mask = cluster_labels == electrode_cluster_id + electrode_cluster_features = features[electrode_mask] + + logging.info( + f"Identified {len(electrode_cluster_features)} electrode candidates from clustering" + ) + + return cluster_labels, electrode_cluster_features + + def _filter_electrode_candidates( + self, + image: np.ndarray, + foreground_mask: np.ndarray, + view_type: ViewType, + cluster_features: np.ndarray, + slic_segments: np.ndarray, + ) -> np.ndarray: + """ + Filter electrode candidates based on shape, appearance, and segment location. + """ + logging.info("Filtering electrode candidates based on shape criteria...") + + valid_candidates = [] + + # Generate basic electrode mask + basic_detector = BasicElectrodeDetector() + mask = basic_detector.get_electrode_mask( + image, + foreground_mask, + view_type, + [DetectionMethod.TRADITIONAL, DetectionMethod.FRST], + ) + mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, (25, 25)) + + for feature_row in cluster_features: + # Extract relevant features + segment_id = feature_row[0] + circularity = feature_row[7] + eccentricity = feature_row[8] + solidity = feature_row[9] + aspect_ratio = feature_row[10] + contrast = feature_row[11] + frst_mean = feature_row[12] + + # Shape-based filtering + shape_criteria_met = ( + circularity >= 0.75 + and eccentricity < 0.75 + and solidity > 0.9 + and aspect_ratio < 1.5 + ) + + if not shape_criteria_met: + continue + + # Check if SLIC segment lies inside the electrode mask + segment_mask = (slic_segments == segment_id).astype(np.uint8) + overlap = cv2.bitwise_and(segment_mask, mask) + overlap_ratio = np.sum(overlap) / max(1, np.sum(segment_mask)) # Avoid div by zero + + if overlap_ratio < 0.5: # Require at least 50% of the segment inside the mask + continue + + # Multi-factor confidence score + shape_score = (circularity + solidity) / 2.0 + confidence = np.abs(contrast) * 0.5 + shape_score * 0.3 + np.abs(frst_mean) * 0.2 + + feature_row[15] = confidence + valid_candidates.append(feature_row) + + return np.array(valid_candidates) + + def _apply_non_maximum_suppression( + self, candidates: np.ndarray, view_type: ViewType + ) -> np.ndarray: + """ + Apply non-maximum suppression to remove closely spaced electrode detections. 
+ """ + logging.info("Applying non-maximum suppression...") + + # Sort candidates by confidence (highest first) + sorted_candidates = sorted(candidates, key=lambda x: x[15], reverse=True) + + final_electrodes = [] + + for candidate in sorted_candidates: + candidate_x = float(candidate[13]) # centroid_x + candidate_y = float(candidate[14]) # centroid_y + + # Check if candidate is too close to already selected electrodes + too_close = False + for selected_electrode in final_electrodes: + selected_x = float(selected_electrode[13]) + selected_y = float(selected_electrode[14]) + + distance = np.sqrt( + (candidate_x - selected_x) ** 2 + (candidate_y - selected_y) ** 2 + ) + + min_dist = self.min_distance or view_type.electrode_cfg.min_distance or 50 + if distance < min_dist: + too_close = True + break + + # Add candidate if it's not too close to existing selections + if not too_close: + final_electrodes.append(candidate) + + logging.info(f"Selected {len(final_electrodes)} final electrodes after NMS") + return np.array(final_electrodes) + + def _create_output_visualization( + self, image: np.ndarray, segments: np.ndarray, final_electrodes: np.ndarray + ) -> np.ndarray: + """ + Create visualization image showing detected electrodes. + """ + if len(final_electrodes) == 0: + return image + + # Extract segment IDs of detected electrodes + electrode_segment_ids = final_electrodes[:, 0].astype(int) + + # Create mask for electrode segments + electrode_mask = np.isin(segments, electrode_segment_ids) + electrode_segments = np.where(electrode_mask, segments, 0) + + # Create visualization with marked boundaries + output_image = mark_boundaries( + image, + electrode_segments, + color=(1, 0, 0), + mode="thick", + ) + + return output_image + + def _segments_to_circles( + self, final_electrodes: np.ndarray + ) -> List[Tuple[float, float, float]]: + """ + Convert detected electrodes to circle representations (x, y, r) using equivalent radius. + """ + if len(final_electrodes) == 0: + return [] + + logging.info(f"Converting {len(final_electrodes)} electrodes to circles...") + circles = [] + + for electrode in final_electrodes: + centroid_x = float(electrode[13]) + centroid_y = float(electrode[14]) + area = float(electrode[5]) + + # Calculate equivalent radius: r = sqrt(area / π) + radius = np.sqrt(area / np.pi) * 2 if area > 0 else 25 + + circles.append((int(centroid_x), int(centroid_y), int(radius))) + + return np.array([circles]) + + def detect( + self, image: np.ndarray, foreground_mask: np.ndarray, view_type: ViewType + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Detect electrodes using SLIC and GMM. 
+        """
+        # Preprocess image
+        preprocessed = self._preprocess_image(image, foreground_mask)
+
+        # Apply SLIC segmentation
+        segments = self._apply_slic_segmentation(preprocessed, foreground_mask, view_type)
+
+        # Extract features from superpixels
+        features = self._extract_superpixel_features(preprocessed, segments, foreground_mask)
+
+        if len(features) == 0:
+            logging.warning("No valid superpixels found for feature extraction")
+            return np.array([]), image, image
+
+        # Cluster superpixels to identify electrode candidates
+        cluster_labels, electrode_candidates = self._cluster_superpixels(features)
+
+        # Filter candidates based on shape criteria
+        filtered_candidates = self._filter_electrode_candidates(
+            image, foreground_mask, view_type, electrode_candidates, segments
+        )
+
+        if len(filtered_candidates) == 0:
+            logging.warning("No electrode candidates passed shape filtering")
+            return np.array([]), image, image
+
+        # Apply non-maximum suppression
+        final_electrodes = self._apply_non_maximum_suppression(filtered_candidates, view_type)
+
+        # Create output visualization
+        output_segments = self._create_output_visualization(image, segments, final_electrodes)
+
+        # Convert segments to circles
+        circles = self._segments_to_circles(final_electrodes)
+        output_image = BasicElectrodeDetector.draw_circles(image, circles)
+
+        logging.info(f"Successfully detected {len(final_electrodes)} electrodes")
+        return circles, output_image, output_segments
diff --git a/src/processing_models/electrode/electrode_mapper.py b/src/processing_models/electrode/electrode_mapper.py
new file mode 100644
index 0000000..3214354
--- /dev/null
+++ b/src/processing_models/electrode/electrode_mapper.py
@@ -0,0 +1,395 @@
+import logging
+import numpy as np
+import pandas as pd
+
+from collections import defaultdict
+from vedo import Mesh, Plotter, Spheres
+from typing import Dict, List, Tuple, Any, Optional
+
+from .view_type import ViewType
+from .detection_method import DetectionMethod
+from .electrode_merger import ElectrodeMerger
+
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+
+
+class ElectrodeMapper:
+
+    def __init__(
+        self,
+        mesh: Mesh,
+        mesh_cap: Mesh,
+        fiducials: pd.DataFrame,
+        center: Optional[np.ndarray] = None,
+    ):
+        """
+        Initialize mapper that maps 2D detected electrodes back to 3D mesh coordinates using ray casting.
+        """
+        self.mesh = mesh.clone()
+        self.mesh_cap = mesh_cap.clone()
+        self.fiducials = fiducials.copy()
+        self.electrodes = None
+
+        # Default center if none provided
+        self.center = center.copy() if center is not None else np.array([0, 0, 0])
+
+        # Setup standard orthogonal camera views
+        self._setup_cameras()
+
+    @staticmethod
+    def _setup_camera(center: np.ndarray, camera_distance: Dict[str, float]) -> Dict[str, Any]:
+        """
+        Configure standard orthogonal camera views: front, back, right, left, top, bottom.
+ """ + dx, dy, dz = ( + camera_distance["x"], + camera_distance["y"], + camera_distance["z"], + ) + cx, cy, cz = center + + return { + "center": center, + "distance": camera_distance, + "views": { + "front": { + "pos": [cx, cy, cz + dz], + "up": [0, 1, 0], + "name": "front", + "description": "Front view (+Z)", + }, + "back": { + "pos": [cx, cy, cz - dz], + "up": [0, 1, 0], + "name": "back", + "description": "Back view (-Z)", + }, + "top": { + "pos": [cx, cy + dy, cz], + "up": [0, 0, -1], + "name": "top", + "description": "Top view (+Y)", + }, + "bottom": { + "pos": [cx, cy - dy, cz], + "up": [0, 0, 1], + "name": "bottom", + "description": "Bottom view (-Y)", + }, + "right": { + "pos": [cx - dx, cy, cz], + "up": [0, 1, 0], + "name": "right", + "description": "Right view (+X)", + }, + "left": { + "pos": [cx + dx, cy, cz], + "up": [0, 1, 0], + "name": "left", + "description": "Left view (-X)", + }, + "front_top": { + "pos": [cx, (cy + dy) * 0.75, (cz + dz) * 0.75], + "up": [0, 1, 0], + "name": "front_top", + "description": "Front-top view (+Z, +Y)", + }, + "front_bottom": { + "pos": [cx, (cy - dy) * 0.75, (cz + dz) * 0.75], + "up": [0, 1, 0], + "name": "front_bottom", + "description": "Front-bottom view (+Z, -Y)", + }, + "back_top": { + "pos": [cx, (cy + dy) * 0.75, (cz - dz) * 0.75], + "up": [0, 1, 0], + "name": "back_top", + "description": "Back-top view (-Z, +Y)", + }, + "back_bottom": { + "pos": [cx, (cy - dy) * 0.75, (cz - dz) * 0.75], + "up": [0, 1, 0], + "name": "back_bottom", + "description": "Back-bottom view (-Z, -Y)", + }, + "front_right": { + "pos": [(cx - dx) * 0.75, cy, (cz + dz) * 0.75], + "up": [0, 1, 0], + "name": "front_right", + "description": "Front-right view (+X, +Z)", + }, + "front_left": { + "pos": [(cx + dx) * 0.75, cy, (cz + dz) * 0.75], + "up": [0, 1, 0], + "name": "front_left", + "description": "Front-left view (-X, +Z)", + }, + "back_right": { + "pos": [(cx - dx) * 0.75, cy, (cz - dz) * 0.75], + "up": [0, 1, 0], + "name": "back_right", + "description": "Back-right view (+X, -Z)", + }, + "back_left": { + "pos": [(cx + dx) * 0.75, cy, (cz - dz) * 0.75], + "up": [0, 1, 0], + "name": "back_left", + "description": "Back-left view (-X, -Z)", + }, + "top_right": { + "pos": [(cx - dx) * 0.75, (cy + dy) * 0.75, cz], + "up": [0, 1, 0], + "name": "top_right", + "description": "Top-right view (+X, +Y)", + }, + "top_left": { + "pos": [(cx + dx) * 0.75, (cy + dy) * 0.75, cz], + "up": [0, 1, 0], + "name": "top_left", + "description": "Top-left view (-X, +Y)", + }, + "bottom_right": { + "pos": [(cx - dx) * 0.75, (cy - dy) * 0.75, cz], + "up": [0, 1, 0], + "name": "top_right", + "description": "Bottom-right view (+X, -Y)", + }, + "bottom_left": { + "pos": [(cx + dx) * 0.75, (cy - dy) * 0.75, cz], + "up": [0, 1, 0], + "name": "top_left", + "description": "Bottom-left view (-X, -Y)", + }, + }, + } + + def _setup_cameras(self) -> None: + """ + Setup camera configurations for the mesh. 
+ """ + # Estimate a reasonable camera distance from mesh bounds + center = self.center.copy() + bounds = self.mesh_cap.bounds() + camera_distance = { + "x": (bounds[1] - bounds[0]) * 2.5, + "y": (bounds[3] - bounds[2]) * 1.75, + "z": (bounds[5] - bounds[4]) * 2.5, + } + + # Define a custom camera position based on fiducials (for cap views) + if self.fiducials is not None: + required_labels = ["NAS", "LPA", "RPA"] + + if set(required_labels).issubset(self.fiducials.index): + # Extract fiducial coordinates + fiducials = { + label: self.fiducials.loc[label].to_numpy(dtype=float) + for label in required_labels + } + + nas, lpa, rpa = fiducials["NAS"], fiducials["LPA"], fiducials["RPA"] + plane_center = np.mean([nas, lpa, rpa], axis=0) + + # Update center.y to position the camera above the cropped mesh + center_y = (bounds[3] - plane_center[2]) / 2.0 + center[1] = center_y + + # Set camera distance in y direction (scaled by bounding box extent) + bbox_y_extent = bounds[3] - bounds[2] + camera_distance["y"] = bbox_y_extent * 2.75 + + # Setup camera positions + self.camera_config = self._setup_camera(center, camera_distance) + + def _pixel_to_ray( + self, + pixel_coords: Tuple[int, int], + view_type: ViewType, + image_size: Tuple[int, int], + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Convert 2D pixel coordinates to 3D world ray (origin and direction). + """ + if self.camera_config is None: + raise ValueError("Camera configuration not set") + + # Calculate camera coordinate system + camera_pos = np.array(self.camera_config["views"][view_type.label]["pos"]) + focal_point = np.array(self.camera_config["center"]) + view_up = np.array(self.camera_config["views"][view_type.label]["up"]) + + # Convert pixel coordinates to normalized device coordinates (-1 to 1) + width, height = image_size[:2] + x_pixel, y_pixel = pixel_coords + x_ndc = (2.0 * x_pixel / width) - 1.0 + y_ndc = 1.0 - (2.0 * y_pixel / height) # Flip Y axis + + # Forward vector (from camera to focal point) + forward = focal_point - camera_pos + forward = forward / np.linalg.norm(forward) + + # Right vector + right = np.cross(forward, view_up) + right = right / np.linalg.norm(right) + + # Corrected up vector + up = np.cross(right, forward) + up = up / np.linalg.norm(up) + + # Calculate ray direction + aspect_ratio = width / height + fov_rad = np.radians(30.0) # Default view angle + + # Calculate the ray direction in camera space + tan_half_fov = np.tan(fov_rad / 2.0) + ray_x = x_ndc * tan_half_fov * aspect_ratio + ray_y = y_ndc * tan_half_fov + + # Transform to world coordinates + ray_origin = camera_pos + ray_direction = forward + ray_x * right + ray_y * up + ray_direction = ray_direction / np.linalg.norm(ray_direction) + + return ray_origin, ray_direction + + def _ray_mesh_intersection( + self, ray_origin: np.ndarray, ray_direction: np.ndarray + ) -> Optional[np.ndarray]: + """ + Find intersection point between ray and mesh surface. 
+ """ + # Perform ray-mesh intersection using Vedo + t = 1000 # Extend ray far enough + ray_end = ray_origin + ray_direction * t + + # Use Vedo's intersectWithLine method + intersection_points = self.mesh_cap.intersect_with_line(ray_origin, ray_end) + + if len(intersection_points) > 0: + # Find closest intersection point to camera + points_array = np.array([p for p in intersection_points]) + distances = np.linalg.norm(points_array - ray_origin, axis=1) + closest_idx = np.argmin(distances) + return points_array[closest_idx] + + return None + + def _flatten_electrode_map( + self, + mapped_electrodes: Dict[ + ViewType, Dict[DetectionMethod, List[Tuple[float, float, float, str]]] + ], + ) -> List[Tuple[ViewType, DetectionMethod, float, float, float, str]]: + """ + Flatten a nested dictionary of 3D electrode coordinates into a list of tuples. + """ + flatten_electrodes = [] + for view, method_dict in mapped_electrodes.items(): + for method, coords_list in method_dict.items(): + for x, y, z, label in coords_list: + flatten_electrodes.append((view, method, x, y, z, label)) + return np.array(flatten_electrodes) + + def map_electrodes_to_3d( + self, + detected_electrodes: Dict[ViewType, Dict[DetectionMethod, List[Tuple[int, int, int, str]]]], + metadata: Dict[ViewType, Dict[str, Any]], + ) -> List[Tuple[ViewType, DetectionMethod, float, float, float, str]]: + """ + Map detected 2D electrodes to 3D mesh coordinates. + """ + mapped_electrodes = defaultdict(dict) + + for view_type, methods in detected_electrodes.items(): + if view_type.label not in self.camera_config["views"]: + logging.warning(f"Unknown view {view_type}") + continue + logging.info(f"Mapping electrodes from {view_type} view") + + # Get image size from metadata + shape = metadata.get(view_type, {}).get("shape") + image_size = (1024, 1024) + if shape is not None and len(shape) >= 2: + image_size = tuple(shape[:2]) + + for method, electrodes in methods.items(): + + view_3d_coords = [] + for x, y, _, label in electrodes: + # Convert pixel to world ray + ray_origin, ray_direction = self._pixel_to_ray((x, y), view_type, image_size) + + # Find intersection with mesh + intersection = self._ray_mesh_intersection(ray_origin, ray_direction) + + if intersection is not None: + intersection = np.append(np.array(intersection, dtype=object), label) + view_3d_coords.append(intersection) + else: + logging.warning( + f"No intersection found for electrode {view_type}.{method} ({x}, {y})" + ) + + mapped_electrodes[view_type][method] = view_3d_coords + + # Flatten map + flatten_electrodes = self._flatten_electrode_map(mapped_electrodes) + + # Cluster electrodes + merger = ElectrodeMerger(15.0) + merged_electrodes = merger.cluster_electrodes(flatten_electrodes) + self.electrodes = merged_electrodes.copy() + + return merged_electrodes + + def visualize_results( + self, + show_fiducials: bool = True, + ) -> None: + """ + Visualize the 3D mesh with mapped electrodes. 
+ """ + if self.electrodes is None: + raise ValueError("Electrodes not merged") + + plotter = Plotter(title="EEG Electrode Mapping Results") + plotter.add(self.mesh) + + # Add electrodes + spheres = Spheres(self.electrodes[:, (2, 3, 4)], r=3.5, c="gray") + plotter.add(spheres) + + # Add fiducial points + if show_fiducials and self.fiducials is not None: + fiducial_spheres = Spheres(self.fiducials, r=2.5, c="red") + plotter.add(fiducial_spheres) + + plotter.show() + + def save( + self, + file_path: str, + electrodes: Optional[ + List[Tuple[ViewType, DetectionMethod, float, float, float, str]] + ] = None, + ) -> None: + """ + Save detected EEG electrodes to a CSV file. + Format: ViewType,DetectionMethod,x,y,z + """ + if electrodes is None: + if self.electrodes is None: + raise ValueError("Electrodes not merged") + electrodes = self.electrodes + + try: + with open(file_path, "w", encoding="utf-8") as f: + for fid in electrodes: + if fid is not None: + f.write( + f"{fid[0].name},{fid[1].name},{fid[2]},{fid[3]},{fid[4]},{fid[5]}\n" + ) + logging.info(f"Electrodes saved to {file_path}") + except OSError as e: + logging.error(f"Failed to save electrodes: {e}") diff --git a/src/processing_models/electrode/electrode_merger.py b/src/processing_models/electrode/electrode_merger.py new file mode 100644 index 0000000..217dd7e --- /dev/null +++ b/src/processing_models/electrode/electrode_merger.py @@ -0,0 +1,176 @@ +import logging +import numpy as np + +from typing import List, Tuple +from scipy.spatial.distance import pdist, squareform + +from .view_type import ViewType +from .detection_method import DetectionMethod + + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + +class ElectrodeMerger: + + def __init__(self, distance_threshold: float = 15.0) -> None: + """ + Initialize electrode clusterer to process and merge electrodes from multiple camera views. + """ + self.distance_threshold = distance_threshold + + def _dfs_cluster_search( + self, + node: int, + cluster: List[int], + visited: List[bool], + distance_matrix: np.ndarray, + threshold: float, + ) -> None: + """ + Depth-first search to find connected components in a cluster. + """ + visited[node] = True + cluster.append(node) + + # Explore all neighbors within threshold distance + for neighbor in range(len(distance_matrix)): + if not visited[neighbor] and distance_matrix[node][neighbor] <= threshold: + self._dfs_cluster_search(neighbor, cluster, visited, distance_matrix, threshold) + + def _find_clusters(self, distance_matrix: np.ndarray, threshold: float) -> List[List[int]]: + """ + Find clusters using distance threshold with connected components approach. + """ + if len(distance_matrix.shape) != 2 or distance_matrix.shape[0] != distance_matrix.shape[1]: + raise ValueError("Distance matrix must be square") + + if threshold < 0: + raise ValueError("Threshold must be non-negative") + + n = len(distance_matrix) + visited = [False] * n + clusters = [] + + # Find all connected components (clusters) + for i in range(n): + if not visited[i]: + cluster = [] + self._dfs_cluster_search(i, cluster, visited, distance_matrix, threshold) + clusters.append(cluster) + + return clusters + + def _find_geometric_median( + self, points: np.ndarray, max_iterations: int = 100, tolerance: float = 1e-6 + ) -> np.ndarray: + """ + Find geometric median using Weiszfeld's algorithm. 
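+
+        Example (minimal sketch with synthetic points):
+
+            import numpy as np
+
+            merger = ElectrodeMerger()
+            pts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0],
+                            [0.0, 1.0, 0.0], [50.0, 50.0, 50.0]])
+            # The geometric median stays close to the tight group of points,
+            # whereas the mean would be pulled towards the outlier.
+            median = merger._find_geometric_median(pts)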
+ """ + points = np.array(points, dtype=np.float64) + + if len(points) == 1: + return points[0] + + if len(points) == 2: + return np.mean(points, axis=0) + + # Start with centroid as initial guess + median = np.mean(points, axis=0) + + for _ in range(max_iterations): + # Calculate distances from current median to all points + distances = np.linalg.norm(points - median, axis=1) + + # Handle points that coincide with current median (avoid division by zero) + non_zero_distances = distances > tolerance + + if not np.any(non_zero_distances): + # All points are at the current median + break + + # Calculate weights (inverse of distances) + weights = np.zeros(len(points)) + weights[non_zero_distances] = 1.0 / distances[non_zero_distances] + + # Skip points that are too close to avoid numerical issues + if np.sum(weights) == 0: + break + + # Calculate new median as weighted average + new_median = np.sum(weights[:, np.newaxis] * points, axis=0) / np.sum(weights) + + # Check for convergence + if np.linalg.norm(new_median - median) < tolerance: + break + + median = new_median + + return median + + def _compute_cluster_centroid( + self, + cluster_electrodes: List[Tuple[ViewType, DetectionMethod, float, float, float, str]], + ) -> Tuple[ViewType, DetectionMethod, float, float, float, str]: + """ + Compute centroid for a cluster with marker priority. + """ + if len(cluster_electrodes) == 1: + return cluster_electrodes[0] + + # Check if any electrode in cluster uses marker method + marker_electrodes = [ + electrode for electrode in cluster_electrodes if electrode[1] == DetectionMethod.MARKER + ] + + if len(marker_electrodes) > 0: + # Find closest marker + cluster_electrodes = marker_electrodes + + # Check if any electrode in cluster is labeled + labeled_electrodes = [ + electrode + for electrode in cluster_electrodes + if electrode[5] != "" and electrode[5] is not None + ] + + if len(labeled_electrodes) > 0: + # Find closest labeled electrode + cluster_electrodes = labeled_electrodes + + # Find geometric median of all points and return closest electrode + all_coords = np.array([(x, y, z) for _, _, x, y, z, _ in cluster_electrodes]) + geometric_median = self._find_geometric_median(all_coords) + + # Find closest actual electrode to geometric median + distances = np.linalg.norm(all_coords - geometric_median, axis=1) + closest_idx = np.argmin(distances) + + return cluster_electrodes[closest_idx] + + def cluster_electrodes( + self, + electrodes: List[Tuple[ViewType, DetectionMethod, float, float, float, str]], + ) -> List[Tuple[ViewType, DetectionMethod, float, float, float, str]]: + """ + Cluster nearby electrodes and compute centroids with labeled MARKER method priority. 
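+
+        Example (usage sketch; coordinates are hypothetical):
+
+            electrodes = [
+                (ViewType.FRONT, DetectionMethod.MARKER, 10.0, 0.0, 0.0, "Fz"),
+                (ViewType.BACK_TOP, DetectionMethod.TRADITIONAL, 12.0, 1.0, 0.0, ""),
+                (ViewType.FRONT, DetectionMethod.FRST, 80.0, 0.0, 0.0, ""),
+            ]
+            merger = ElectrodeMerger(distance_threshold=15.0)
+            # The first two detections lie closer together than the threshold and
+            # collapse into one electrode (the labelled MARKER detection is kept);
+            # the third detection stays separate.
+            merged = merger.cluster_electrodes(electrodes)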
+ """ + # Extract coordinates and create metadata mapping + coordinates = np.array([[x, y, z] for _, _, x, y, z, _ in electrodes]) + + # Compute pairwise distances + distances = squareform(pdist(coordinates, metric="euclidean")) + + # Find clusters using distance threshold + clusters = self._find_clusters(distances, self.distance_threshold) + logging.info(f"Detected {len(clusters)} electrode clusters") + + # Compute centroids for each cluster + centroids = [] + for cluster_indices in clusters: + cluster_electrodes = [electrodes[i] for i in cluster_indices] + centroid = self._compute_cluster_centroid(cluster_electrodes.copy()) + centroids.append(centroid) + logging.info(f"Electrodes merged, reduced count from {len(electrodes)} to {len(centroids)}") + + return np.array(centroids) diff --git a/src/processing_models/electrode/frst.py b/src/processing_models/electrode/frst.py new file mode 100644 index 0000000..0bfd337 --- /dev/null +++ b/src/processing_models/electrode/frst.py @@ -0,0 +1,89 @@ +import cv2 +import numpy as np +from typing import List + + +class FRST: + """ + Fast Radial Symmetry Transform (FRST) for circular pattern detection. + + Reference: https://ieeexplore.ieee.org/document/1217601 + """ + + def run( + self, + image: np.ndarray, + radii: List[int], + alpha: float = 2.0, + beta: float = 0.1, + ) -> np.ndarray: + """ + Run the FRST on an input grayscale image. + """ + rows, cols = image.shape + + # --- Gradient computation --- + grad_x = cv2.Sobel(image, cv2.CV_64F, 1, 0, ksize=3) + grad_y = cv2.Sobel(image, cv2.CV_64F, 0, 1, ksize=3) + grad_mag = np.hypot(grad_x, grad_y) + grad_dir = np.arctan2(grad_y, grad_x) + + # --- Threshold gradients --- + grad_mask = grad_mag > (beta * np.max(grad_mag)) + y_idx, x_idx = np.nonzero(grad_mask) # indices of strong gradients + + # --- Accumulator --- + S = np.zeros((rows, cols), dtype=np.float64) + + for radius in radii: + # Projection images + O_p = np.zeros((rows, cols), dtype=np.float64) + O_n = np.zeros((rows, cols), dtype=np.float64) + M_p = np.zeros((rows, cols), dtype=np.float64) + M_n = np.zeros((rows, cols), dtype=np.float64) + + # Precompute cosine/sine for all valid points + cos_dir = np.cos(grad_dir[y_idx, x_idx]) + sin_dir = np.sin(grad_dir[y_idx, x_idx]) + + # Positive offsets (gradient points outward) + x_p = (x_idx - radius * cos_dir).astype(int) + y_p = (y_idx - radius * sin_dir).astype(int) + + # Negative offsets (gradient points inward) + x_n = (x_idx + radius * cos_dir).astype(int) + y_n = (y_idx + radius * sin_dir).astype(int) + + # Clip to image boundaries + valid_p = (0 <= x_p) & (x_p < cols) & (0 <= y_p) & (y_p < rows) + valid_n = (0 <= x_n) & (x_n < cols) & (0 <= y_n) & (y_n < rows) + + # Accumulate contributions + np.add.at(O_p, (y_p[valid_p], x_p[valid_p]), 1) + np.add.at( + M_p, + (y_p[valid_p], x_p[valid_p]), + grad_mag[y_idx[valid_p], x_idx[valid_p]], + ) + + np.add.at(O_n, (y_n[valid_n], x_n[valid_n]), 1) + np.add.at( + M_n, + (y_n[valid_n], x_n[valid_n]), + grad_mag[y_idx[valid_n], x_idx[valid_n]], + ) + + # Symmetry contribution + F_p = np.divide(M_p, O_p, out=np.zeros_like(M_p), where=O_p > 0) + F_n = np.divide(M_n, O_n, out=np.zeros_like(M_n), where=O_n > 0) + F = F_p - F_n + + # Gaussian smoothing (scale by radius) + sigma = 0.25 * radius + F_smooth = cv2.GaussianBlur(F, (0, 0), sigma) + + # Weighted accumulation + S += np.power(F_smooth, alpha) + + # Normalize by number of radii + return S / len(radii) diff --git a/src/processing_models/electrode/marker_detector.py 
b/src/processing_models/electrode/marker_detector.py new file mode 100644 index 0000000..7b2fa4f --- /dev/null +++ b/src/processing_models/electrode/marker_detector.py @@ -0,0 +1,306 @@ +import cv2 +import logging +import numpy as np + +from typing import List, Optional, Tuple + +from .frst import FRST +from .view_type import ViewType +from .util import ColorEnhancementUtil +from .detection_method import DetectionMethod + + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + +class MarkerDetector: + def __init__( + self, + min_area: Optional[int] = None, + min_distance: Optional[int] = None, + min_radius: Optional[int] = None, + max_radius: Optional[int] = None, + ) -> None: + """ + Initialize green marker detector. + """ + self.min_area = min_area + self.min_distance = min_distance + self.min_radius = min_radius + self.max_radius = max_radius + + def _preprocess_image(self, image: np.ndarray) -> np.ndarray: + """ + Enhance green regions and convert to a normalized grayscale for processing. + """ + # Enhance green regions in HSV color space (overboost saturation and value) + enhanced = ColorEnhancementUtil.enhance_green_in_hsv(image, 60.0, 25.0, 2.5, 2.5) + + # Convert to LAB and extract a-channel + lab_image = cv2.cvtColor(enhanced, cv2.COLOR_RGB2LAB) + _, a_channel, _ = cv2.split(lab_image) + + # Smooth the a-channel to reduce noise + a_channel_smooth = cv2.GaussianBlur(a_channel.astype(np.float32), (3, 3), 1.0) + + # Convert to uint8 (for thresholding) + a_uint8 = cv2.normalize(a_channel_smooth, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8) + return a_uint8 + + def _apply_morphological_operations(self, binary_mask: np.ndarray) -> np.ndarray: + """ + Clean up binary mask using opening and closing operations. + """ + kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15)) + kernel_large = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (30, 30)) + + # Remove small noise with opening + opened = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel_small) + + # Fill small holes with closing + closed = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, kernel_large) + + # Make structures a bit smaller again + reopened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel_small) + + # Refine shapes with erosion-dilation + eroded = cv2.erode(reopened, kernel_small, iterations=1) + refined = cv2.dilate(eroded, kernel_large, iterations=1) + + return refined + + def _filter_contours( + self, morph_mask: np.ndarray, view_type: ViewType, use_frst: bool = False + ) -> np.ndarray: + """ + Filter contours by area and circularity, applying view-specific masks. 
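+
+        Circularity is measured as 4*pi*area / perimeter**2, which is 1.0 for a
+        perfect circle and approaches 0 for elongated shapes. Sketch with
+        hypothetical values:
+
+            import numpy as np
+
+            area, perimeter = 1000.0, 120.0
+            circularity = 4 * np.pi * area / perimeter ** 2  # ~0.87, kept
+            # Contours at or below 0.3 (0.25 when use_frst is True) are rejected.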
+ """ + contours, _ = cv2.findContours(morph_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + filtered = np.zeros_like(morph_mask) + + for contour in contours: + # View-specific exclusions + if hasattr(view_type, "name"): + x, y, w, h = cv2.boundingRect(contour) + if view_type.name == "BACK" and y > 700: + continue + if view_type.name == "BACK_TOP" and y > 850: + continue + if view_type.name == "BACK_RIGHT" and y > 750 and x < 612: + continue + + area = cv2.contourArea(contour) + min_area = self.min_area or view_type.marker_cfg.min_area or 750 + if area < min_area: + continue + + perimeter = cv2.arcLength(contour, True) + if perimeter == 0: + continue + + circularity = 4 * np.pi * area / (perimeter * perimeter) + threshold = 0.25 if use_frst else 0.3 + if circularity <= threshold: + continue + + cv2.drawContours(filtered, [contour], -1, 255, -1) + + return filtered + + def _apply_dog_and_hough( + self, + image: np.ndarray, + a_uint8: np.ndarray, + filtered_mask: np.ndarray, + view_type: ViewType, + use_frst: bool = False, + ) -> Optional[np.ndarray]: + """ + Detect circular markers using Difference of Gaussians + Hough Transform. + """ + # Apply DoG filtering + masked = cv2.bitwise_and(a_uint8, a_uint8, mask=filtered_mask) + blur_small = cv2.GaussianBlur(masked.astype(np.float32), (0, 0), 2.5) + blur_large = cv2.GaussianBlur(masked.astype(np.float32), (0, 0), 5.0) + dog = blur_small - blur_large + + dog_normalized = cv2.normalize(dog, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8) + dog_filtered = cv2.medianBlur(dog_normalized, 5) + + # Hough Circle parameters + param2 = 25 if use_frst else 30 + min_dist = self.min_distance or view_type.marker_cfg.min_distance or 75 + min_r = self.min_radius or view_type.marker_cfg.min_radius or 25 + max_r = self.max_radius or view_type.marker_cfg.max_radius or 40 + + circles = cv2.HoughCircles( + dog_filtered, + cv2.HOUGH_GRADIENT, + dp=1.0, + minDist=min_dist, + param1=50, + param2=param2, + minRadius=min_r, + maxRadius=max_r, + ) + + # Validate detected circles + if circles is not None: + circles = np.uint16(np.around(circles)) + valid_circles = [] + + gray_original = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + + for x, y, r in circles[0, :]: + if y >= filtered_mask.shape[0] or x >= filtered_mask.shape[1]: + continue + + if filtered_mask[y, x] > 0 and gray_original[y, x] < 200: + valid_circles.append((x, y, r)) + + circles = np.array([valid_circles], dtype=np.uint16) if valid_circles else None + + return circles + + def detect_traditional( + self, + image: np.ndarray, + foreground_mask: np.ndarray, + view_type: ViewType, + ) -> Optional[np.ndarray]: + """ + Detect markers using thresholding + morphology + Hough transform. + """ + a_uint8 = self._preprocess_image(image) + + fg_pixels = a_uint8[foreground_mask > 0] + percentile_thresh = min(np.percentile(fg_pixels, 5), 128) if len(fg_pixels) > 0 else 128 + _, binary = cv2.threshold(a_uint8, percentile_thresh, 255, cv2.THRESH_BINARY_INV) + binary = cv2.bitwise_and(binary, binary, mask=foreground_mask) + + morph_mask = self._apply_morphological_operations(binary) + filtered_mask = self._filter_contours(morph_mask, view_type, use_frst=False) + + return self._apply_dog_and_hough(image, a_uint8, filtered_mask, view_type, use_frst=False) + + def detect_frst( + self, + image: np.ndarray, + foreground_mask: np.ndarray, + view_type: ViewType, + ) -> Optional[np.ndarray]: + """ + Detect markers using FRST (Fast Radial Symmetry Transform). 
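+
+        Example (usage sketch; image_rgb and foreground_mask are assumed to be a
+        loaded RGB image and its binary foreground mask):
+
+            detector = MarkerDetector()
+            circles = detector.detect_frst(image_rgb, foreground_mask, ViewType.FRONT)
+            # circles has shape (1, N, 3) with one (x, y, radius) per marker, or is None.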
+ """ + a_uint8 = self._preprocess_image(image) + + # Enhance circular structures + tophat_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (25, 25)) + tophat = cv2.morphologyEx(a_uint8, cv2.MORPH_TOPHAT, tophat_kernel) + frst_input = cv2.subtract(a_uint8, tophat) + + radii_range = list(range(15, 50, 5)) + frst = FRST() + frst_response = frst.run(frst_input, radii_range, alpha=1.0, beta=0.1) + + # Apply foreground mask + eroded_mask = cv2.erode( + foreground_mask, + cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (25, 25)), + iterations=1, + ) + frst_response = cv2.bitwise_and( + frst_response, frst_response, mask=eroded_mask.astype(np.uint8) + ) + + frst_normalized = cv2.normalize(frst_response, None, 0, 255, cv2.NORM_MINMAX).astype( + np.uint8 + ) + _, frst_binary = cv2.threshold(frst_normalized, 0, 255, cv2.THRESH_OTSU) + frst_binary = cv2.bitwise_and(frst_binary, frst_binary, mask=foreground_mask) + + morph_mask = self._apply_morphological_operations(frst_binary) + filtered_mask = self._filter_contours(morph_mask, view_type, use_frst=True) + + return self._apply_dog_and_hough(image, a_uint8, filtered_mask, view_type, use_frst=True) + + def _combine_results( + self, + view_type: ViewType, + circles_traditional: Optional[np.ndarray], + circles_frst: Optional[np.ndarray], + ) -> Optional[np.ndarray]: + """ + Merge circles from both methods, removing overlaps. + """ + combined = [] + + if circles_traditional is not None: + combined.extend( + (float(x), float(y), float(r), "traditional") for x, y, r in circles_traditional[0] + ) + if circles_frst is not None: + combined.extend((float(x), float(y), float(r), "frst") for x, y, r in circles_frst[0]) + + if not combined: + return None + + if len({m for _, _, _, m in combined}) == 1: + return np.array([[(x, y, r) for x, y, r, _ in combined]], dtype=np.uint16) + + # Remove overlaps, keep larger circles + unique = [] + for x1, y1, r1, method1 in sorted(combined, key=lambda c: c[2], reverse=True): + overlap = False + for i, (x2, y2, r2, method2) in enumerate(unique): + dist = np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2) + min_dist = self.min_distance or view_type.marker_cfg.min_distance or 0 + if dist < (r1 + r2) or dist < min_dist: + if r1 > r2: + unique[i] = (x1, y1, r1, method1) + overlap = True + break + if not overlap: + unique.append((x1, y1, r1, method1)) + + if unique: + return np.array([[(x, y, r) for x, y, r, _ in unique]], dtype=np.uint16) + + return None + + @staticmethod + def draw_circles(image: np.ndarray, circles: np.ndarray) -> np.ndarray: + """ + Draw detected circles and their centers. + """ + output = image.copy() + if circles is not None: + for x, y, r in np.uint16(np.around(circles))[0, :]: + cv2.circle(output, (x, y), r, (0, 255, 0), 3) + cv2.circle(output, (x, y), 2, (255, 0, 0), 5) + return output + + def detect( + self, + image: np.ndarray, + foreground_mask: np.ndarray, + view_type: ViewType, + methods: List[DetectionMethod], + ) -> Tuple[Optional[np.ndarray], np.ndarray]: + """ + Detect green markers using selected methods and combine results. 
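+
+        Example (usage sketch; inputs as in detect_traditional / detect_frst):
+
+            detector = MarkerDetector(min_radius=20, max_radius=45)
+            circles, overlay = detector.detect(
+                image_rgb,
+                foreground_mask,
+                ViewType.FRONT,
+                methods=[DetectionMethod.TRADITIONAL, DetectionMethod.FRST],
+            )
+            # overlay is a copy of the input image with the detected circles drawn.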
+ """ + circles_traditional, circles_frst = None, None + + if DetectionMethod.TRADITIONAL in methods: + logging.info("Running traditional green marker detection...") + circles_traditional = self.detect_traditional(image, foreground_mask, view_type) + + if DetectionMethod.FRST in methods: + logging.info("Running FRST green marker detection...") + circles_frst = self.detect_frst(image, foreground_mask, view_type) + + circles = self._combine_results(view_type, circles_traditional, circles_frst) + output_image = self.draw_circles(image, circles) + + return circles, output_image diff --git a/src/processing_models/electrode/marker_labeler.py b/src/processing_models/electrode/marker_labeler.py new file mode 100644 index 0000000..aea4926 --- /dev/null +++ b/src/processing_models/electrode/marker_labeler.py @@ -0,0 +1,499 @@ +import cv2 +import logging +import numpy as np +import matplotlib.pyplot as plt + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +from .view_type import ViewType + + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + +@dataclass +class RegionConfig: + center: Union[float, Tuple[float, float]] + tolerance: Optional[float] = None + angle: Optional[float] = None + + +class MarkerLabeler: + + BACK_TOP_VERTICAL_LABELS = ["Pz", "POz", "Oz"] + BACK_TOP_HORIZONTAL_LABELS = ["O1", "Oz", "O2"] + FRONT_TOP_LABELS = ["Fz"] + FRONT_RIGHT_LABELS = ["Fp1"] + FRONT_LEFT_LABELS = ["Fp2"] + BACK_RIGHT_LABELS = ["TP10"] + BACK_LEFT_LABELS = ["TP9"] + + def __init__(self): + """ + Initialize the labeler that handles labeling of EEG markers based on spatial alignment and heuristics. + """ + + def _find_tightest_group( + self, candidates: np.ndarray, axis: int, group_size: int + ) -> np.ndarray: + """ + Find the tightest group of markers along an axis. + """ + best_group = None + min_span = float("inf") + + for i in range(len(candidates) - group_size + 1): + group = candidates[i : i + group_size] + span = group[-1, axis] - group[0, axis] + + if span < min_span: + min_span = span + best_group = group + + return best_group + + def _find_aligned_markers( + self, + markers: np.ndarray, + axis: int, + center: float, + tolerance: float, + min_count: int, + max_count: int, + ) -> np.ndarray: + """ + Find markers aligned along a specific axis. + """ + # Filter markers within tolerance band + perp_axis = 1 - axis + candidates = markers[np.abs(markers[:, perp_axis] - center) < tolerance] + + if len(candidates) == 0: + return np.array([]) + + # Sort along the alignment axis + candidates = candidates[np.argsort(candidates[:, axis])] + + # If fewer candidates than needed, return none + if len(candidates) < min_count: + return np.array([]) + + # If more candidates than needed, find tightest group + if len(candidates) > max_count: + candidates = self._find_tightest_group(candidates, axis, max_count) + + return candidates + + def _filter_back_top_vertical_markers( + self, + markers: np.ndarray, + width: int, + tolerance_ratio_x: float = 0.05, + ) -> Tuple[np.ndarray, list[str], RegionConfig]: + """ + Select vertically aligned markers around the image center. 
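+
+        Example (sketch with synthetic marker coordinates for a 1024-pixel-wide image):
+
+            import numpy as np
+
+            markers = np.array([[510.0, 200.0], [500.0, 450.0],
+                                [515.0, 700.0], [120.0, 300.0]])
+            labeler = MarkerLabeler()
+            selected, labels, cfg = labeler._filter_back_top_vertical_markers(markers, width=1024)
+            # The three markers near x = 512 are kept (sorted top to bottom) and
+            # receive the labels "Pz", "POz", "Oz"; the off-centre marker is dropped.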
+ """ + center_x = width / 2 + tolerance_x = width * tolerance_ratio_x + + selected = self._find_aligned_markers( + markers=markers, + axis=1, # Sort by y (vertical) + center=center_x, + tolerance=tolerance_x, + min_count=3, + max_count=3, + ) + + labels = self.BACK_TOP_VERTICAL_LABELS[: len(selected)] + config = RegionConfig(center=center_x, tolerance=tolerance_x) + + return selected, labels, config + + def _filter_back_top_horizontal_markers( + self, + markers: np.ndarray, + height: int, + position_ratio: float = 0.8, + tolerance_ratio_y: float = 0.1, + ) -> Tuple[np.ndarray, list[str], RegionConfig]: + """ + Select horizontally aligned markers at a specific height. + """ + center_y = height * position_ratio + tolerance_y = height * tolerance_ratio_y + + selected = self._find_aligned_markers( + markers=markers, + axis=0, # Sort by x (horizontal) + center=center_y, + tolerance=tolerance_y, + min_count=3, + max_count=3, + ) + + labels = self.BACK_TOP_HORIZONTAL_LABELS[: len(selected)] + config = RegionConfig(center=center_y, tolerance=tolerance_y) + + return selected, labels, config + + def _find_angled_electrode( + self, + candidates: np.ndarray, + ref_x: float, + ref_y: float, + target_angle: float, + angle_tolerance: float, + left: bool = False, + right: bool = False, + ) -> Optional[np.ndarray]: + """ + Find electrode at a specific angle from reference point. + """ + if len(candidates) == 0: + return None + + best_marker = None + min_angle_diff = float("inf") + + for marker in candidates: + dx = float(marker[0]) - float(ref_x) + dy = float(marker[1]) - float(ref_y) + + # Calculate angle from horizontal + # For left side: use negative dx to get positive angle + # For right side: use positive dx + if left: + angle = np.arctan2(dy, -dx) + elif right: + angle = np.arctan2(dy, dx) + else: + return None + + # Check if angle is within tolerance + angle_diff = abs(angle - target_angle) + + if angle_diff <= angle_tolerance and angle_diff < min_angle_diff: + min_angle_diff = angle_diff + best_marker = marker + + return best_marker + + def _filter_front_markers( + self, + markers: np.ndarray, + width: int, + tolerance_ratio_x: float = 0.05, + tolerance_angle_degrees: float = 15.0, + ) -> Tuple[np.ndarray, list[str], RegionConfig, RegionConfig]: + """ + Select markers that form a triangle near the image center. 
+ """ + selected = [] + labels = [] + center_x = width / 2 + tolerance_x = width * tolerance_ratio_x + + selected_vertical = self._find_aligned_markers( + markers=markers, + axis=1, # Sort by y (vertical) + center=center_x, + tolerance=tolerance_x, + min_count=1, + max_count=1, + ) + + if len(selected_vertical) == 0: + return np.array([]), [], None + + # Find angled markers below the top marker + top_marker = selected_vertical[0] + selected.append(top_marker) + labels.append(self.FRONT_TOP_LABELS[0]) + top_x, top_y = top_marker[0], top_marker[1] + target_angle_rad = np.radians(90.0 - 30.0) # 30 degrees from vertical + tolerance_angle_rad = np.radians(tolerance_angle_degrees) + + # Find left marker + left_candidates = markers[(markers[:, 0] < center_x) & (markers[:, 1] > top_y)] + left_marker = self._find_angled_electrode( + left_candidates, + top_x, + top_y, + target_angle_rad, + tolerance_angle_rad, + left=True, + ) + if left_marker is not None: + selected.append(left_marker) + labels.append(self.FRONT_LEFT_LABELS[0]) + + # Find right marker + right_candidates = markers[(markers[:, 0] > center_x) & (markers[:, 1] > top_y)] + right_marker = self._find_angled_electrode( + right_candidates, + top_x, + top_y, + target_angle_rad, + tolerance_angle_rad, + right=True, + ) + if right_marker is not None: + selected.append(right_marker) + labels.append(self.FRONT_RIGHT_LABELS[0]) + + # If left/right markers not found, return closest to the right/left on the same height + if left_marker is None and len(left_candidates) > 0: + closest_left = left_candidates[np.argmin(np.abs(left_candidates[:, 0] - top_x))] + + if right_marker is not None: + # Compare distances from center_x + dist_left = abs(closest_left[0] - center_x) + dist_right = abs(right_marker[0] - center_x) + if np.isclose(dist_left, dist_right, rtol=0.25): # tolerance 25% + selected.append(closest_left) + labels.append(self.FRONT_LEFT_LABELS[0]) + else: + selected.append(closest_left) + labels.append(self.FRONT_LEFT_LABELS[0]) + + if right_marker is None and len(right_candidates) > 0: + closest_right = right_candidates[np.argmin(np.abs(right_candidates[:, 0] - top_x))] + + if left_marker is not None: + # Compare distances from center_x + dist_right = abs(closest_right[0] - center_x) + dist_left = abs(left_marker[0] - center_x) + if np.isclose(dist_right, dist_left, rtol=0.25): # tolerance 25% + selected.append(closest_right) + labels.append(self.FRONT_RIGHT_LABELS[0]) + else: + selected.append(closest_right) + labels.append(self.FRONT_RIGHT_LABELS[0]) + + config = RegionConfig(center=(top_x, top_y), tolerance=None, angle=target_angle_rad) + + return np.array(selected), labels, config + + def _filter_back_side_markers( + self, + markers: np.ndarray, + width: int, + height: int, + left: bool = False, + right: bool = False, + ) -> Tuple[np.ndarray, list[str], RegionConfig]: + """ + Select vertically aligned markers around the image center. 
+ """ + selected = [] + labels = [] + + center_x = width / 2 + center_y = height / 2 + + # Split markers into left and right halves + left_markers = markers[(markers[:, 0] < center_x) & (markers[:, 1] > center_y)] + right_markers = markers[(markers[:, 0] >= center_x) & (markers[:, 1] > center_y)] + + # Find the marker with the largest y in each half + left_marker = left_markers[np.argmax(left_markers[:, 1])] if len(left_markers) > 0 else None + right_marker = ( + right_markers[np.argmax(right_markers[:, 1])] if len(right_markers) > 0 else None + ) + + if left and left_marker is not None: + selected.append(left_marker) + labels.append(self.BACK_LEFT_LABELS[0]) + if right and right_marker is not None: + selected.append(right_marker) + labels.append(self.BACK_RIGHT_LABELS[0]) + + config = RegionConfig( + center=(int(center_x), int(center_y)), + tolerance=None, + angle=np.radians(45.0), + ) + + return selected, labels, config + + def _assign_labels( + self, + markers: np.ndarray, + selected: np.ndarray, + labels: list[str], + labeled_markers: np.ndarray, + ) -> None: + """ + Assign labels to selected markers in the labeled array. + """ + for marker, label in zip(selected, labels): + # Find index of this marker in original array + idx = np.where((markers == marker).all(axis=1))[0] + + if len(idx) > 0: + idx = idx[0] + # Don't overwrite if label already exists (e.g., "Z" from both filters) + if labeled_markers[idx, -1] is None or labeled_markers[idx, -1] == label: + labeled_markers[idx, -1] = label + + def label_markers( + self, + image: np.ndarray, + markers: np.ndarray, + view_type: ViewType, + visualize: bool = False, + ) -> np.ndarray: + """ + Label markers based on view type alignment. + """ + if len(markers) == 0: + return np.empty((0, 3), dtype=object) + logging.info("Labeling markers for view type: %s", view_type.name) + + height, width = image.shape[:2] + + # Configs for visualization + vert_config = None + horiz_config = None + angle_config = None + + # Initialize labeled markers array + labeled_markers = markers.copy() + + if view_type == ViewType.BACK_TOP: + # Filter and label vertical markers + vert_selected, vert_labels, vert_config = self._filter_back_top_vertical_markers( + markers, width + ) + self._assign_labels(markers, vert_selected, vert_labels, labeled_markers) + + # Filter and label horizontal markers + horiz_selected, horiz_labels, horiz_config = self._filter_back_top_horizontal_markers( + markers, height + ) + self._assign_labels(markers, horiz_selected, horiz_labels, labeled_markers) + elif view_type == ViewType.FRONT: + selected, labels, angle_config = self._filter_front_markers(markers, width, 0.05, 5.0) + self._assign_labels(markers, selected, labels, labeled_markers) + elif view_type == ViewType.BACK_RIGHT: + selected, labels, angle_config = self._filter_back_side_markers( + markers, width, height, right=True + ) + self._assign_labels(markers, selected, labels, labeled_markers) + elif view_type == ViewType.BACK_LEFT: + selected, labels, angle_config = self._filter_back_side_markers( + markers, width, height, left=True + ) + self._assign_labels(markers, selected, labels, labeled_markers) + else: + logging.warning("No labeling rules defined for view type: %s", view_type) + + if visualize: + MarkerLabeler.visualize(image, labeled_markers, vert_config, horiz_config, angle_config) + + return labeled_markers + + @staticmethod + def visualize( + image: np.ndarray, + labeled_markers: np.ndarray, + vert_config: RegionConfig, + horiz_config: RegionConfig, + angle_config: 
RegionConfig, + ) -> None: + """ + Visualize markers and tolerance regions. + """ + vis_image = image.copy() + + # Draw markers + for marker in labeled_markers: + x, y, _, label = marker + + if label is None: + # Unlabeled: red circle + cv2.circle(vis_image, (x, y), 8, (255, 0, 0), -1) + else: + # Labeled: green circle with text + cv2.circle(vis_image, (x, y), 12, (0, 255, 0), -1) + cv2.putText( + vis_image, + str(label), + (x + 10, y - 10), + cv2.FONT_HERSHEY_COMPLEX, + 1.0, + (0, 0, 0), + 2, + cv2.LINE_AA, + ) + + # Create figure + # plt.figure(figsize=(12, 10)) + + # Draw tolerance regions + if vert_config: + plt.axvline( + vert_config.center, + color="blue", + linestyle="--", + linewidth=1, + alpha=0.25, + label="Vertical center", + ) + plt.axvspan( + vert_config.center - vert_config.tolerance, + vert_config.center + vert_config.tolerance, + color="blue", + alpha=0.1, + ) + + if horiz_config: + plt.axhline( + horiz_config.center, + color="blue", + linestyle="--", + linewidth=1, + alpha=0.25, + label="Horizontal line", + ) + plt.axhspan( + horiz_config.center - horiz_config.tolerance, + horiz_config.center + horiz_config.tolerance, + color="blue", + alpha=0.1, + ) + + if angle_config: + # Draw angle line from center downwards + radius = image.shape[0] // 2 + dx = np.cos(angle_config.angle) * radius + dy = np.sin(angle_config.angle) * radius + + x_end_left = int(angle_config.center[0] - dx) + y_end_left = int(angle_config.center[1] + dy) + plt.plot( + [angle_config.center[0], x_end_left], + [angle_config.center[1], y_end_left], + color="blue", + linestyle="--", + linewidth=1, + alpha=0.25, + ) + x_end_right = int(angle_config.center[0] + dx) + y_end_right = int(angle_config.center[1] + dy) + plt.plot( + [angle_config.center[0], x_end_right], + [angle_config.center[1], y_end_right], + color="blue", + linestyle="--", + linewidth=1, + alpha=0.25, + ) + + # plt.imshow(vis_image) + # plt.axis("off") + # plt.tight_layout() + # plt.show() + return vis_image diff --git a/src/processing_models/electrode/params/__init__.py b/src/processing_models/electrode/params/__init__.py new file mode 100644 index 0000000..9057731 --- /dev/null +++ b/src/processing_models/electrode/params/__init__.py @@ -0,0 +1,19 @@ +from .illumination_correction_params import CLAHEParams +from .noise_reduction_params import BilateralFilterParams, GuidedFilterParams, NLMParams +from .processing_params import ProcessingParams +from .sharpening_params import ( + AdaptiveSharpeningParams, + SelectiveSharpeningParams, + UnsharpMaskingParams, +) + +__all__ = [ + "CLAHEParams", + "BilateralFilterParams", + "GuidedFilterParams", + "NLMParams", + "ProcessingParams", + "AdaptiveSharpeningParams", + "SelectiveSharpeningParams", + "UnsharpMaskingParams", +] diff --git a/src/processing_models/electrode/params/illumination_correction_params.py b/src/processing_models/electrode/params/illumination_correction_params.py new file mode 100644 index 0000000..53df288 --- /dev/null +++ b/src/processing_models/electrode/params/illumination_correction_params.py @@ -0,0 +1,16 @@ +from typing import Tuple +from dataclasses import dataclass + +from ..color_space import ColorSpace + + +@dataclass +class CLAHEParams: + """ + Parameters for contrast limited adaptive histogram equalization. 
+ """ + + enabled: bool = True + clip_limit: float = 2.5 + tile_grid_size: Tuple[int, int] = (8, 8) + color_space: ColorSpace = ColorSpace.LAB diff --git a/src/processing_models/electrode/params/noise_reduction_params.py b/src/processing_models/electrode/params/noise_reduction_params.py new file mode 100644 index 0000000..3028f3d --- /dev/null +++ b/src/processing_models/electrode/params/noise_reduction_params.py @@ -0,0 +1,43 @@ +import cv2 + +from dataclasses import dataclass + +from ..color_space import ColorSpace + + +@dataclass +class BilateralFilterParams: + """ + Parameters for bilateral filtering. + """ + + enabled: bool = True + diameter: int = 9 + sigma_color: float = 75.0 + sigma_space: float = 75.0 + border_type: int = cv2.BORDER_REFLECT_101 + + +@dataclass +class GuidedFilterParams: + """ + Parameters for guided filtering. + """ + + enabled: bool = False + radius: int = 3 + epsilon: float = 1e-3 + color_space: ColorSpace = ColorSpace.LAB + + +@dataclass +class NLMParams: + """ + Parameters for non-local means denoising. + """ + + enabled: bool = False + filtering_strength: float = 5.0 + template_window_size: int = 7 + search_window_size: int = 21 + color_space: ColorSpace = ColorSpace.LAB diff --git a/src/processing_models/electrode/params/processing_params.py b/src/processing_models/electrode/params/processing_params.py new file mode 100644 index 0000000..6dbaf5c --- /dev/null +++ b/src/processing_models/electrode/params/processing_params.py @@ -0,0 +1,51 @@ +from typing import List, Tuple +from dataclasses import dataclass, field + +from ..color_space import ColorSpace +from .illumination_correction_params import CLAHEParams +from .noise_reduction_params import BilateralFilterParams, GuidedFilterParams, NLMParams +from .sharpening_params import ( + AdaptiveSharpeningParams, + SelectiveSharpeningParams, + UnsharpMaskingParams, +) + + +@dataclass +class ProcessingParams: + """ + Parameters for image preprocessing used in view loader. + """ + + # General preprocessing + target_size: Tuple[int, int] = (1024, 1024) + gray_world: bool = True + normalize: bool = False + + # Noise reduction + bilateral: BilateralFilterParams = field(default_factory=BilateralFilterParams) + guided: GuidedFilterParams = field(default_factory=GuidedFilterParams) + nlm: NLMParams = field(default_factory=NLMParams) + + # Illumination correction + clahe: CLAHEParams = field(default_factory=CLAHEParams) + + # Sharpening + adaptive: AdaptiveSharpeningParams = field(default_factory=AdaptiveSharpeningParams) + selective: SelectiveSharpeningParams = field(default_factory=SelectiveSharpeningParams) + unsharp: UnsharpMaskingParams = field(default_factory=UnsharpMaskingParams) + + # Color space processing + target_color_spaces: List[ColorSpace] = None + + # Edge enhancement for DoG + Hough + edge_enhancement: bool = True + unsharp_mask_strength: float = 1.5 + + # Superpixel preprocessing + superpixel_preprocess: bool = True + median_filter_size: int = 5 + + def __post_init__(self): + if self.target_color_spaces is None: + self.target_color_spaces = [ColorSpace.RGB, ColorSpace.HSV, ColorSpace.LAB] diff --git a/src/processing_models/electrode/params/sharpening_params.py b/src/processing_models/electrode/params/sharpening_params.py new file mode 100644 index 0000000..90521a8 --- /dev/null +++ b/src/processing_models/electrode/params/sharpening_params.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass + + +@dataclass +class SharpeningParamsBase: + """ + Base parameters for sharpening methods. 
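+
+    Each concrete sharpening method below inherits these defaults. Sketch with
+    hypothetical values:
+
+        params = UnsharpMaskingParams(sigma=1.0, strength=2.0)
+        # threshold keeps its inherited default of 10.0; enabled defaults to True.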
+ """ + + sigma: float = 1.5 + strength: float = 2.5 + threshold: float = 10.0 + + +@dataclass +class AdaptiveSharpeningParams(SharpeningParamsBase): + """ + Parameters for adaptive sharpening. + """ + + enabled: bool = False + + +@dataclass +class SelectiveSharpeningParams(SharpeningParamsBase): + """ + Parameters for selective sharpening. + """ + + enabled: bool = False + + +@dataclass +class UnsharpMaskingParams(SharpeningParamsBase): + """ + Parameters for unsharp masking. + """ + + enabled: bool = True diff --git a/src/processing_models/electrode/util/__init__.py b/src/processing_models/electrode/util/__init__.py new file mode 100644 index 0000000..5910f2e --- /dev/null +++ b/src/processing_models/electrode/util/__init__.py @@ -0,0 +1,15 @@ +from .background_mask_util import BackgroundMaskUtil +from .color_enhancement_util import ColorEnhancementUtil +from .color_quantization_util import ColorQuantizationUtil +from .illumination_correction_util import IlluminationCorrectionUtil +from .noise_reduction_util import NoiseReductionUtil +from .sharpening_util import SharpeningUtil + +__all__ = [ + "BackgroundMaskUtil", + "ColorEnhancementUtil", + "ColorQuantizationUtil", + "IlluminationCorrectionUtil", + "NoiseReductionUtil", + "SharpeningUtil", +] diff --git a/src/processing_models/electrode/util/background_mask_util.py b/src/processing_models/electrode/util/background_mask_util.py new file mode 100644 index 0000000..88ab138 --- /dev/null +++ b/src/processing_models/electrode/util/background_mask_util.py @@ -0,0 +1,82 @@ +import cv2 +import logging +import numpy as np + +from typing import Optional, Tuple + + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + +class BackgroundMaskUtil: + @staticmethod + def analyze_white_background(image: np.ndarray) -> int: + """ + Analyze white background characteristics in an RGB image to suggest an optimal sensitivity. + """ + hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) + + # Strict white mask (very bright, very low saturation) + strict_white_mask = cv2.inRange(hsv, np.array([0, 0, 240]), np.array([255, 15, 255])) + + if np.count_nonzero(strict_white_mask) == 0: + return 40 # fallback if no white region is detected + + # Extract S and V channels where mask is white + s_channel = hsv[:, :, 1][strict_white_mask > 0] + v_channel = hsv[:, :, 2][strict_white_mask > 0] + + # Compute statistics more efficiently + s_mean, s_std = cv2.meanStdDev(s_channel) + v_mean, v_std = cv2.meanStdDev(v_channel) + + s_max = int(np.max(s_channel)) + v_min = int(np.min(v_channel)) + + # Base sensitivity depending on background brightness/variation + if s_max < 10 and v_min > 245: + base = 8 # pure white + elif s_max < 25 and v_min > 220: + base = 20 # paper white + elif s_max < 40 and v_min > 180: + base = 35 # lit but uneven + else: + base = 50 # shadowed + + # Adjustment based on variation + adjustment = min(int(s_std * 2 + v_std * 0.5), 20) + sensitivity = base + adjustment + + return max(0, min(100, sensitivity)) # clamp + + @staticmethod + def generate_background_mask( + image: np.ndarray, + sensitivity: Optional[int] = None, + kernel_size: Tuple[int, int] = (25, 25), + ) -> np.ndarray: + """ + Generate a binary mask for white or near-white backgrounds. 
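+
+        Example (usage sketch; image_rgb is assumed to be an RGB photo taken
+        against a white backdrop):
+
+            import cv2
+
+            background = BackgroundMaskUtil.generate_background_mask(image_rgb)
+            foreground = cv2.bitwise_not(background)  # head/cap region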
+ """ + if sensitivity is None: + sensitivity = BackgroundMaskUtil.analyze_white_background(image) + logging.info(f"Auto-detected sensitivity: {sensitivity}") + + sensitivity = max(0, min(100, sensitivity)) + + hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) + + # Define thresholds for white/near-white detection + lower_white = np.array([0, 0, 255 - sensitivity]) + upper_white = np.array([255, sensitivity, 255]) + mask = cv2.inRange(hsv, lower_white, upper_white) + + if kernel_size[0] > 0 and kernel_size[1] > 0: + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_size) + + # First remove small noise + mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) + # Then close small gaps + mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) + + return mask diff --git a/src/processing_models/electrode/util/color_enhancement_util.py b/src/processing_models/electrode/util/color_enhancement_util.py new file mode 100644 index 0000000..6513a1e --- /dev/null +++ b/src/processing_models/electrode/util/color_enhancement_util.py @@ -0,0 +1,130 @@ +import cv2 +import numpy as np + +from typing import Tuple, Literal + + +class ColorEnhancementUtil: + + @staticmethod + def enhance_green_channel( + image: np.ndarray, + boost_factor: float, + max_value: float, + preserve_luminance: bool, + ) -> np.ndarray: + """ + Enhance the green channel with controlled boosting. + """ + # Convert to float for processing + result = image.astype(np.float32) + + # Preserve luminance option + if preserve_luminance: + # Calculate original luminance + original_luminance = np.mean(result, axis=2) + + # Boost green channel + result[:, :, 1] = np.minimum(result[:, :, 1] * boost_factor, max_value) + + # Recalculate luminance and normalize + new_luminance = np.mean(result, axis=2) + result *= (original_luminance / new_luminance)[:, :, np.newaxis] + + else: + # Simple green channel boost + result[:, :, 1] = np.minimum(result[:, :, 1] * boost_factor, max_value) + + return result.astype(np.uint8) + + @staticmethod + def enhance_green_difference( + image: np.ndarray, + excess_boost: float, + normalization_method: Literal["minmax", "zscore"], + ) -> np.ndarray: + """ + Enhance green by emphasizing its difference from other channels. + """ + # Convert to float for processing + result = image.astype(np.float32) + + # Calculate green channel excess + g_excess = result[:, :, 1] - (result[:, :, 0] + result[:, :, 2]) / 2 + + # Normalize green excess + if normalization_method == "minmax": + # Min-Max normalization + if g_excess.max() > g_excess.min(): + g_excess_normalized = ( + (g_excess - g_excess.min()) / (g_excess.max() - g_excess.min()) * 255 + ) + else: + g_excess_normalized = np.zeros_like(g_excess) + elif normalization_method == "zscore": + # Z-score normalization + mean, std = np.mean(g_excess), np.std(g_excess) + g_excess_normalized = ((g_excess - mean) / (std + 1e-8)) * 64 + 128 + else: + raise ValueError( + f"Unsupported normalization method: {normalization_method}. Use 'minmax' or 'zscore'." + ) + + # Enhance green channel + enhanced = result.copy() + enhanced[:, :, 1] = np.minimum(enhanced[:, :, 1] + g_excess_normalized * excess_boost, 255) + + return enhanced.astype(np.uint8) + + @staticmethod + def enhance_green_in_hsv( + image: np.ndarray, + green_hue: float, + hue_width: float, + saturation_boost: float, + value_boost: float, + ) -> np.ndarray: + """ + Enhance green regions in HSV color space. 
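+
+        Example (the parameter values used by the marker detector in this
+        package; image_rgb is an assumed input image):
+
+            enhanced = ColorEnhancementUtil.enhance_green_in_hsv(
+                image_rgb, green_hue=60.0, hue_width=25.0,
+                saturation_boost=2.5, value_boost=2.5,
+            )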
+ """ + # Convert to HSV + hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32) + + # Create weight map based on proximity to green hue + hue_proximity = np.exp(-0.5 * ((hsv[:, :, 0] - green_hue) / hue_width) ** 2) + + # Boost saturation for pixels with hue close to green + hsv[:, :, 1] = np.minimum( + hsv[:, :, 1] + hsv[:, :, 1] * hue_proximity * saturation_boost, 255 + ) + + # Boost value for pixels with hue close to green + hsv[:, :, 2] = np.minimum(hsv[:, :, 2] + hsv[:, :, 2] * hue_proximity * value_boost, 255) + + return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB) + + @staticmethod + def enhance_green_in_lab( + image: np.ndarray, + green_lab: Tuple[float, float, float] = (0, -128, 128), + distance_scale: float = 50.0, + a_channel_boost: float = 0.5, + ) -> np.ndarray: + """ + Enhance green regions in Lab color space. + """ + # Convert to Lab + lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB).astype(np.float32) + + # Calculate distance from reference green point + lab_distance = np.linalg.norm(lab - np.array(green_lab), axis=2) + + # Create proximity weight map + lab_proximity = np.exp(-0.5 * (lab_distance / distance_scale) ** 2) + + # Boost a* channel for pixels close to green + lab[:, :, 1] = np.minimum( + lab[:, :, 1] + lab[:, :, 1] * lab_proximity * a_channel_boost, 255 + ) + + return cv2.cvtColor(lab.astype(np.uint8), cv2.COLOR_LAB2RGB) diff --git a/src/processing_models/electrode/util/color_quantization_util.py b/src/processing_models/electrode/util/color_quantization_util.py new file mode 100644 index 0000000..d171824 --- /dev/null +++ b/src/processing_models/electrode/util/color_quantization_util.py @@ -0,0 +1,331 @@ +import cv2 +import numpy as np + +from typing import Optional, List, Literal, Tuple + +from ..color_space import ColorSpace + + +class ColorQuantizationUtil: + + @staticmethod + def kmeans_quantization( + image: np.ndarray, + n_colors: int, + color_space: ColorSpace, + ) -> np.ndarray: + """ + Quantize colors using K-means clustering with flexible color space options. + """ + # Prepare image for clustering + if color_space == ColorSpace.LAB: + # Convert to LAB color space for perceptually uniform quantization + pixels = cv2.cvtColor(image, cv2.COLOR_RGB2LAB).astype(np.float32) + elif color_space == ColorSpace.RGB: + # Use RGB color space + pixels = image.astype(np.float32) + else: + raise ValueError(f"K-means not supported for {color_space} color space.") + + # Reshape pixels for K-means (N_pixels x channels) + h, w = image.shape[:2] + pixels_reshaped = pixels.reshape(-1, 3) + + # Perform K-means clustering + criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2) + _, labels, centers = cv2.kmeans( + pixels_reshaped, n_colors, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS + ) + + # Reconstruct image + labels = labels.flatten() + quantized = centers[labels].reshape(h, w, 3) + + # Convert back to original color space if LAB was used + if color_space == ColorSpace.LAB: + quantized = cv2.cvtColor(quantized.astype(np.uint8), cv2.COLOR_LAB2RGB) + else: + quantized = quantized.astype(np.uint8) + + return quantized + + @staticmethod + def lab_hsv_quantization( + image: np.ndarray, + n_colors: int, + foreground_mask: Optional[np.ndarray], + verbose: bool = False, + ) -> np.ndarray: + """ + Advanced color quantization with emphasis on green regions. 
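+
+        Example (usage sketch; foreground_mask should be boolean, e.g. derived
+        from a BackgroundMaskUtil mask):
+
+            quantized = ColorQuantizationUtil.lab_hsv_quantization(
+                image_rgb, n_colors=8, foreground_mask=foreground_mask > 0
+            )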
+ """ + # Default max recursion depth + max_recursion = 24 + + # Default green detection criteria + green_criteria = { + "hue_min": 30, + "hue_max": 90, + "saturation_min": 64, + "value_min": 64, + } + + def _lab_quantization( + image: np.ndarray, n_colors: int, foreground_mask: np.ndarray + ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]: + """ + Quantize image in LAB color space. + """ + # Convert to LAB + lab_image = cv2.cvtColor(image, cv2.COLOR_RGB2LAB).astype(np.float32) + + # Extract foreground pixels in LAB space + foreground_lab = lab_image[foreground_mask] + + # Check for valid foreground pixels + if len(foreground_lab) == 0: + return None, None + + # K-means clustering + try: + criteria = ( + cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, + 100, + 0.2, + ) + _, labels, centers = cv2.kmeans( + foreground_lab, + n_colors, + None, + criteria, + 10, + cv2.KMEANS_RANDOM_CENTERS, + ) + return labels, centers + except cv2.error: + return None, None + + def _hsv_quantization( + image: np.ndarray, n_colors: int, foreground_mask: np.ndarray + ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]: + """ + Quantize image with special handling for green regions. + """ + # Convert to LAB and HSV + lab_image = cv2.cvtColor(image, cv2.COLOR_RGB2LAB).astype(np.float32) + hsv_image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) + + # Green region detection + green_mask = ( + (hsv_image[:, :, 0] >= green_criteria["hue_min"]) + & (hsv_image[:, :, 0] <= green_criteria["hue_max"]) + & (hsv_image[:, :, 1] >= green_criteria["saturation_min"]) + & (hsv_image[:, :, 2] >= green_criteria["value_min"]) + ) + + # Combined mask: foreground AND green + green_foreground_mask = foreground_mask & green_mask + + # Prepare pixels + foreground_lab = lab_image[foreground_mask] + + if len(foreground_lab) == 0: + return None, None + + # If green foreground pixels exist, weight them + if np.any(green_foreground_mask): + green_lab_pixels = lab_image[green_foreground_mask] + + if len(green_lab_pixels) > 0: + # Combine regular foreground with duplicated green pixels + weighted_pixels = np.vstack([foreground_lab] + [green_lab_pixels] * 5) + + try: + criteria = ( + cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, + 100, + 0.2, + ) + _, labels_weighted, centers_weighted = cv2.kmeans( + weighted_pixels, + n_colors, + None, + criteria, + 10, + cv2.KMEANS_RANDOM_CENTERS, + ) + + # Recompute cluster assignments + distances = np.zeros((len(foreground_lab), n_colors), dtype=np.float32) + for i in range(n_colors): + distances[:, i] = np.sum( + (foreground_lab - centers_weighted[i]) ** 2, axis=1 + ) + labels = np.argmin(distances, axis=1) + + return labels, centers_weighted + except cv2.error: + return None, None + + return None, None + + def _detect_green_clusters(centers: np.ndarray) -> np.ndarray: + """ + Identify green clusters based on HSV color space. 
+ """ + # Convert centers to HSV + centers_hsv = np.zeros((centers.shape[0], 3), dtype=np.uint8) + for i, center in enumerate(centers): + # Create a temporary 1x1 image with the LAB color + temp_lab = np.zeros((1, 1, 3), dtype=np.uint8) + temp_lab[0, 0] = center.astype(np.uint8) + + # Convert to HSV + temp_rgb = cv2.cvtColor(temp_lab, cv2.COLOR_LAB2RGB) + temp_hsv = cv2.cvtColor(temp_rgb, cv2.COLOR_RGB2HSV) + centers_hsv[i] = temp_hsv[0, 0] + + # Calculate "greenness" metric for each cluster center + green_clusters = centers_hsv[ + (centers_hsv[:, 0] >= green_criteria["hue_min"]) + & (centers_hsv[:, 0] <= green_criteria["hue_max"]) + & (centers_hsv[:, 1] >= green_criteria["saturation_min"]) + & (centers_hsv[:, 2] >= green_criteria["value_min"]) + ] + + return green_clusters + + # Prepare foreground mask + if foreground_mask is None: + foreground_mask = np.ones(image.shape[:2], dtype=bool) + + # Extract foreground indices + foreground_indices = np.where(foreground_mask) + + # LAB quantization + verbose and print("Quantizing in LAB color space...") + labels, centers = _lab_quantization(image, n_colors, foreground_mask) + + # Check green clusters + if labels is None or centers is None: + return image + + green_clusters = _detect_green_clusters(centers) + + # If not enough green clusters, try HSV approach + if len(green_clusters) < 1: + verbose and print("No green clusters found, trying quantization in HSV color space...") + labels, centers = _hsv_quantization(image, n_colors, foreground_mask) + + if labels is None or centers is None: + # If quantization fails, increase colors or return original + return ( + ColorQuantizationUtil.lab_hsv_quantization(image, n_colors + 2, foreground_mask) + if n_colors < max_recursion + else image + ) + + green_clusters = _detect_green_clusters(centers) + + # Still no green clusters + if len(green_clusters) < 1: + verbose and print(f"No green clusters found, trying {n_colors + 2} colors...") + return ( + ColorQuantizationUtil.lab_hsv_quantization(image, n_colors + 2, foreground_mask) + if n_colors < max_recursion + else image + ) + + # Create output image + result = image.copy() + + # Map foreground pixels to quantized colors + for i, (y, x) in enumerate(zip(foreground_indices[0], foreground_indices[1])): + center_lab = centers[labels[i]] + + # Convert LAB center to RGB + temp_lab = np.zeros((1, 1, 3), dtype=np.uint8) + temp_lab[0, 0] = center_lab.astype(np.uint8) + temp_rgb = cv2.cvtColor(temp_lab, cv2.COLOR_LAB2RGB) + result[y, x] = temp_rgb[0, 0] + + return result + + @staticmethod + def median_cut_quantization( + image: np.ndarray, + n_colors: int, + color_reduction_method: Literal["avg", "representative"], + ) -> np.ndarray: + """ + Apply median cut algorithm for color quantization. + """ + + def _median_cut_recursive(pixels: np.ndarray, depth: int, method: str) -> List[np.ndarray]: + """ + Recursively apply median cut to color space. 
+ """ + # Base case: reached depth or single pixel + if depth == 0 or len(pixels) <= 1: + if method == "avg": + return np.mean(pixels, axis=0) + else: # representative + return pixels[len(pixels) // 2] + + # Find channel with highest range + ranges = np.ptp(pixels, axis=0) + channel = np.argmax(ranges) + + # Sort by this channel + sorted_pixels = pixels[pixels[:, channel].argsort()] + + # Split at median + mid = len(sorted_pixels) // 2 + + # Recursively process both halves + return [ + _median_cut_recursive(sorted_pixels[:mid], depth - 1, method), + _median_cut_recursive(sorted_pixels[mid:], depth - 1, method), + ] + + # Flatten image and prepare for recursion + pixels = image.reshape(-1, 3).astype(np.float32) + + # Determine recursion depth + depth = int(np.ceil(np.log2(n_colors))) + + # Flatten palette recursively + def _flatten_palette(colors): + palette = [] + + def _flatten(item): + if isinstance(item, np.ndarray): + palette.append(item) + else: + for subitem in item: + _flatten(subitem) + + _flatten(colors) + return np.array(palette) + + # Generate palette + palette = _flatten_palette(_median_cut_recursive(pixels, depth, color_reduction_method)) + + # Limit to requested number of colors + if len(palette) > n_colors: + # Use K-means to further reduce colors if needed + criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2) + _, _, centers = cv2.kmeans( + palette, n_colors, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS + ) + palette = centers + + # Map pixels to closest palette color + quantized = np.zeros_like(pixels) + for i, pixel in enumerate(pixels): + # Find closest color in palette + distances = np.sum((palette - pixel) ** 2, axis=1) + closest_color_idx = np.argmin(distances) + quantized[i] = palette[closest_color_idx] + + return quantized.reshape(image.shape).astype(np.uint8) diff --git a/src/processing_models/electrode/util/illumination_correction_util.py b/src/processing_models/electrode/util/illumination_correction_util.py new file mode 100644 index 0000000..7e0f0fc --- /dev/null +++ b/src/processing_models/electrode/util/illumination_correction_util.py @@ -0,0 +1,60 @@ +import cv2 +import numpy as np + +from typing import Literal, Tuple + +from ..color_space import ColorSpace + + +class IlluminationCorrectionUtil: + + @staticmethod + def clahe_equalization( + image: np.ndarray, + clip_limit: float, + tile_grid_size: Tuple[int, int], + color_space: Literal[ColorSpace.LAB, ColorSpace.YUV], + ) -> np.ndarray: + """ + Apply Contrast Limited Adaptive Histogram Equalization (CLAHE) + to normalize image lighting. 
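+
+        Example (the defaults from CLAHEParams in this package; image_rgb is an
+        assumed input image):
+
+            out = IlluminationCorrectionUtil.clahe_equalization(
+                image_rgb, clip_limit=2.5, tile_grid_size=(8, 8),
+                color_space=ColorSpace.LAB,
+            )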
+ + Reference: https://ieeexplore.ieee.org/document/109340 + Alternatives: + - https://arxiv.org/pdf/2004.07945 (Adaptive Local Contrast Normalization) + - https://www.ipol.im/pub/art/2014/107/article_lr.pdf (Multi-Scale Retinex) + """ + # CLAHE object + clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size) + + # Processing based on color space + if color_space == ColorSpace.LAB: + # Convert to LAB + lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB) + + # Split channels + l, a, b = cv2.split(lab) + + # Apply CLAHE to L channel + l_clahe = clahe.apply(l) + + # Merge and convert back to RGB + lab_clahe = cv2.merge((l_clahe, a, b)) + return cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB) + + elif color_space == ColorSpace.YUV: + # Convert to YUV + yuv = cv2.cvtColor(image, cv2.COLOR_RGB2YUV) + + # Split channels + y, u, v = cv2.split(yuv) + + # Apply CLAHE to luminance channel + y_eq = clahe.apply(y) + + # Merge and convert back to RGB + yuv_eq = cv2.merge((y_eq, u, v)) + return cv2.cvtColor(yuv_eq, cv2.COLOR_YUV2RGB) + + else: + raise ValueError(f"CLAHE not supported for {color_space} color space.") diff --git a/src/processing_models/electrode/util/noise_reduction_util.py b/src/processing_models/electrode/util/noise_reduction_util.py new file mode 100644 index 0000000..ed6fadb --- /dev/null +++ b/src/processing_models/electrode/util/noise_reduction_util.py @@ -0,0 +1,203 @@ +import cv2 +import numpy as np + +from ..color_space import ColorSpace + + +class NoiseReductionUtil: + + @staticmethod + def bilateral_filter( + image: np.ndarray, + diameter: int, + sigma_color: float, + sigma_space: float, + border_type: int, + ) -> np.ndarray: + """ + Apply bilateral filter to reduce noise while preserving edges. + + Reference: https://homepages.inf.ed.ac.uk/rbf/CVonline/LOCAL_COPIES/MANDUCHI1/Bilateral_Filtering.html + """ + # Apply bilateral filter + return cv2.bilateralFilter( + src=image, + d=diameter, + sigmaColor=sigma_color, + sigmaSpace=sigma_space, + borderType=border_type, + ) + + @staticmethod + def guided_filter( + image: np.ndarray, + radius: int, + epsilon: float, + color_space: ColorSpace, + ) -> np.ndarray: + """ + Apply guided filter for edge-preserving smoothing. 
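+
+        Illustrative usage (radius and epsilon below are placeholder values; the
+        implementation relies on cv2.ximgproc.guidedFilter, which is only
+        available in the opencv-contrib build of OpenCV):
+
+            smoothed = NoiseReductionUtil.guided_filter(
+                image=rgb_image,
+                radius=8,
+                epsilon=0.01,
+                color_space=ColorSpace.LAB,
+            )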
+ + References: https://link.springer.com/chapter/10.1007/978-3-642-15549-9_1 + """ + # Default guidance boost + guidance_boost = {"r": 1.0, "g": 1.3, "b": 1.0} + + # Convert image to float32 for processing + image_float = image.astype(np.float32) / 255.0 + + # Create guidance image with optional channel boosting + def _create_guidance(img: np.ndarray) -> np.ndarray: + guidance = img.copy() + guidance[:, :, 0] *= guidance_boost["r"] + guidance[:, :, 1] *= guidance_boost["g"] + guidance[:, :, 2] *= guidance_boost["b"] + return np.clip(guidance, 0, 1) + + # Filtering method selection + if color_space == ColorSpace.RGB: + # Convert to LAB for filtering + lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB).astype(np.float32) + + # Create guidance image + guide_rgb = _create_guidance(image_float) + gray_guide = ( + 0.1 * guide_rgb[:, :, 0] + 0.8 * guide_rgb[:, :, 1] + 0.1 * guide_rgb[:, :, 2] + ) + + # Normalize LAB channels + l = lab[:, :, 0] / 255.0 + a = (lab[:, :, 1] + 128) / 255.0 + b = (lab[:, :, 2] + 128) / 255.0 + + # Apply guided filter with different epsilon for each channel + l_filtered = cv2.ximgproc.guidedFilter(gray_guide, l, radius, epsilon) + a_filtered = cv2.ximgproc.guidedFilter(gray_guide, a, radius, epsilon * 1.5) + b_filtered = cv2.ximgproc.guidedFilter(gray_guide, b, radius, epsilon * 1.5) + + # Reconstruct LAB image + lab_filtered = np.zeros_like(lab) + lab_filtered[:, :, 0] = l_filtered * 255.0 + lab_filtered[:, :, 1] = a_filtered * 255.0 - 128 + lab_filtered[:, :, 2] = b_filtered * 255.0 - 128 + + return cv2.cvtColor(lab_filtered.astype(np.uint8), cv2.COLOR_LAB2RGB) + + elif color_space == ColorSpace.LAB: + # Convert to LAB + lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB).astype(np.float32) + + # Normalize channels + l = lab[:, :, 0] / 255.0 + a = (lab[:, :, 1] + 128) / 255.0 + b = (lab[:, :, 2] + 128) / 255.0 + + # Apply guided filter with different parameters + l_filtered = cv2.ximgproc.guidedFilter(l, l, radius, epsilon) + a_filtered = cv2.ximgproc.guidedFilter(a, a, radius, epsilon * 1.5) + b_filtered = cv2.ximgproc.guidedFilter(b, b, radius, epsilon * 1.5) + + # Reconstruct LAB image + lab_filtered = np.zeros_like(lab) + lab_filtered[:, :, 0] = l_filtered * 255.0 + lab_filtered[:, :, 1] = a_filtered * 255.0 - 128 + lab_filtered[:, :, 2] = b_filtered * 255.0 - 128 + + return cv2.cvtColor(lab_filtered.astype(np.uint8), cv2.COLOR_LAB2RGB) + + elif color_space == ColorSpace.GRAY: + # Create guidance image with green channel boost + guidance = _create_guidance(image_float) + gray_guide = ( + 0.2126 * guidance[:, :, 0] + 0.7152 * guidance[:, :, 1] + 0.0722 * guidance[:, :, 2] + ) + + # Filter each channel + result = np.zeros_like(image_float) + for i in range(3): + result[:, :, i] = cv2.ximgproc.guidedFilter( + gray_guide, image_float[:, :, i], radius, epsilon + ) + + return (result * 255).astype(np.uint8) + + else: + # Apply guided filter to each channel + result = np.zeros_like(image_float) + for i in range(3): + result[:, :, i] = cv2.ximgproc.guidedFilter( + image_float[:, :, i], image_float[:, :, i], radius, epsilon + ) + + return (result * 255).astype(np.uint8) + + @staticmethod + def nlm_denoising( + image: np.ndarray, + filtering_strength: float, + color_space: ColorSpace, + ) -> np.ndarray: + """ + Apply Non-Local Means denoising with adaptive channel processing. 
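+
+        Illustrative usage (the filtering strength below is a placeholder, not a
+        tuned default; only the RGB and LAB color spaces are handled):
+
+            denoised = NoiseReductionUtil.nlm_denoising(
+                image=rgb_image,
+                filtering_strength=10.0,
+                color_space=ColorSpace.LAB,
+            )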
+ + Reference: https://www.ipol.im/pub/art/2011/bcm_nlm/article.pdf + """ + # Perform channel-wise denoising + if color_space == ColorSpace.RGB: + # Default channel strength scaling + channel_strengths = { + "r": 1.0, # Red channel full strength + "g": 0.7, # Green channel slightly reduced + "b": 1.0, # Blue channel full strength + } + + # Split RGB channels + r, g, b = cv2.split(image) + + # Denoise each channel with scaled strength + r_denoised = cv2.fastNlMeansDenoising( + r, None, filtering_strength * channel_strengths["r"] + ) + g_denoised = cv2.fastNlMeansDenoising( + g, None, filtering_strength * channel_strengths["g"] + ) + b_denoised = cv2.fastNlMeansDenoising( + b, None, filtering_strength * channel_strengths["b"] + ) + + # Merge denoised channels + return cv2.merge((r_denoised, g_denoised, b_denoised)) + + elif color_space == ColorSpace.LAB: + # Default channel strength scaling + channel_strengths = { + "l": 1.0, # Luminance full strength + "a": 0.7, # Color-a channel reduced + "b": 0.7, # Color-b channel reduced + } + + # Convert to LAB color space + lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB) + + # Split LAB channels + l, a, b = cv2.split(lab) + + # Denoise each channel with scaled strength + l_denoised = cv2.fastNlMeansDenoising( + l, None, filtering_strength * channel_strengths["l"] + ) + a_denoised = cv2.fastNlMeansDenoising( + a, None, filtering_strength * channel_strengths["a"] + ) + b_denoised = cv2.fastNlMeansDenoising( + b, None, filtering_strength * channel_strengths["b"] + ) + + # Merge denoised channels + lab_denoised = cv2.merge((l_denoised, a_denoised, b_denoised)) + + # Convert back to RGB color space + return cv2.cvtColor(lab_denoised, cv2.COLOR_LAB2RGB) + + else: + raise ValueError(f"NLM denoising not supported for {color_space} color space.") diff --git a/src/processing_models/electrode/util/sharpening_util.py b/src/processing_models/electrode/util/sharpening_util.py new file mode 100644 index 0000000..b184af8 --- /dev/null +++ b/src/processing_models/electrode/util/sharpening_util.py @@ -0,0 +1,123 @@ +import cv2 +import numpy as np + +from typing import Optional + + +class SharpeningUtil: + + @staticmethod + def adaptive_sharpen( + image: np.ndarray, + sigma: float, + strength: float, + threshold: float, + ) -> np.ndarray: + """ + Apply adaptive sharpening based on edge detection. 
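+
+        Illustrative usage (sigma, strength and threshold below are placeholder
+        values, not tuned defaults):
+
+            sharpened = SharpeningUtil.adaptive_sharpen(
+                image=rgb_image,
+                sigma=1.5,
+                strength=1.0,
+                threshold=5.0,
+            )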
+ """ + # Convert to float for calculations + img_float = image.astype(np.float32) + + # Detect edges using Laplacian + if len(image.shape) == 3: # Color image + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + else: # Already grayscale + gray = image.copy() + + # Calculate edge map + laplacian = cv2.Laplacian(gray, cv2.CV_32F) + edge_map = np.abs(laplacian) + + # Create adaptive mask based on edges + mask = edge_map > threshold + mask = mask.astype(np.float32) + + # Smooth the mask to create gradual transitions + mask = cv2.GaussianBlur(mask, (0, 0), sigma * 2) + + # Apply unsharp mask algorithm + blurred = cv2.GaussianBlur(img_float, (0, 0), sigma) + sharpened = img_float + strength * (img_float - blurred) + + # Expand the mask to match image dimensions if it's a color image + if len(image.shape) == 3: + mask = np.expand_dims(mask, axis=2) + mask = np.repeat(mask, 3, axis=2) + + # Apply adaptive blending based on mask + result = img_float * (1 - mask) + sharpened * mask + + # Clip values to valid range and convert back to original type + return np.clip(result, 0, 255).astype(image.dtype) + + @staticmethod + def selective_sharpen( + image: np.ndarray, + sigma: float, + strength: float, + threshold: Optional[float], + ) -> np.ndarray: + """ + Apply selective sharpening with emphasis on t the green channel. + """ + # Default channel strengths if not provided + channel_strengths = ( + strength * 0.7, # Red + strength, # Green + strength * 0.7, # Blue + ) + + # Split channels + channels = cv2.split(image) + + # Sharpen each channel with different intensities + sharpened_channels = [ + SharpeningUtil.unsharp_masking(channels[i], sigma, channel_strengths[i], threshold) + for i in range(3) + ] + + # Merge channels + return cv2.merge(sharpened_channels) + + @staticmethod + def unsharp_masking( + image: np.ndarray, + sigma: float, + strength: float, + threshold: Optional[float], + ) -> np.ndarray: + """ + Apply unsharp masking for edge enhancement. 
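+
+        The core operation is the classic unsharp mask,
+        sharpened = (1 + strength) * image - strength * GaussianBlur(image, sigma),
+        optionally restricted to pixels whose Laplacian magnitude exceeds the
+        threshold so that flat regions stay untouched.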
+ """ + # Input validation + if not isinstance(image, np.ndarray): + raise TypeError("Input must be a NumPy array.") + + # Ensure float processing + image_float = image.astype(np.float32) + + # Create blurred version + blurred = cv2.GaussianBlur(image_float, (0, 0), sigma) + + # Apply unsharp masking + sharpened = cv2.addWeighted(image_float, 1.0 + strength, blurred, -strength, 0) + + # Optional thresholding + if threshold is not None: + # Create edge mask + if image_float.ndim == 2: + # Grayscale image + edges = cv2.Laplacian(image_float, cv2.CV_32F) + mask = np.abs(edges) > threshold + sharpened = image_float + (sharpened - image_float) * mask + else: + # Color image + gray = cv2.cvtColor(image_float, cv2.COLOR_RGB2GRAY) + edges = cv2.Laplacian(gray, cv2.CV_32F) + mask = np.abs(edges) > threshold + mask = mask[:, :, np.newaxis] + sharpened = image_float + (sharpened - image_float) * mask + + # Clip to valid range + return np.clip(sharpened, 0, 255).astype(np.uint8) diff --git a/src/processing_models/electrode/view_loader.py b/src/processing_models/electrode/view_loader.py new file mode 100644 index 0000000..1cfed29 --- /dev/null +++ b/src/processing_models/electrode/view_loader.py @@ -0,0 +1,494 @@ +import cv2 +import os +import logging +import numpy as np +import matplotlib.pyplot as plt + +from pathlib import Path +from collections import defaultdict +from typing import Dict, List, Optional, Union + +from .view_type import ViewType +from .params import ProcessingParams +from .marker_labeler import MarkerLabeler +from .marker_detector import MarkerDetector +from .detection_method import DetectionMethod +from .electrode_detector import ElectrodeDetector +from .basic_electrode_detector import BasicElectrodeDetector +from concurrent.futures import ProcessPoolExecutor, as_completed +from .util import BackgroundMaskUtil, IlluminationCorrectionUtil, NoiseReductionUtil + + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + +class ViewLoader: + def __init__( + self, + views_path: Union[str, Path], + params: Optional[ProcessingParams] = None, + ) -> None: + """ + Initialize view loader for working with 2d images of 3d head scans. + """ + self.metadata = {} + self.images = self._load_images(views_path) + self.images_raw = self.images.copy() + self.preprocessed = {} + self.detected = defaultdict(dict) + self.params = params or ProcessingParams() + + if params is not None: + # Avoid preprocessing data in util tests + self.preprocess_data() + + def _load_images(self, views_path: Union[str, Path]) -> Dict[ViewType, np.ndarray]: + """ + Load a set of 2d views each captured from different engle. + """ + if not os.path.exists(views_path): + logging.error(f"Views folder not found: {views_path}") + return None + logging.info(f"Loading views from {views_path}") + + # Find images for each view type + images = {} + for view_type in ViewType: + image_path = f"{views_path}/{view_type.name}.png" + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + images[view_type] = image + + # Store metadata + self.metadata[view_type] = { + "path": image_path, + "shape": image.shape, + "dtype": str(image.dtype), + } + logging.info(f"Loaded {view_type.name}: {image.shape}") + + return images + + def preprocess_data(self) -> None: + """ + Comprehensive preprocessing for all loaded images. 
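+
+        For every loaded view the following entries are stored in
+        self.preprocessed[view_type]: "original_image", "background_mask",
+        "foreground_mask" and "preprocessed_image". The processing order is
+        resize -> gray-world color constancy -> noise reduction (bilateral,
+        guided, NLM) -> CLAHE -> background masking -> optional normalization,
+        with the optional steps gated by the corresponding ProcessingParams flags.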
+        """
+        if self.images is None:
+            raise ValueError("Views not loaded")
+        logging.info("Preprocessing data...")
+
+        for view_type, image in self.images.items():
+            # Basic preprocessing (common for all methods)
+            logging.info(f"Preprocessing {view_type.name}...")
+            preprocessed_variants = {"original_image": image.copy()}
+
+            # Resize
+            preprocessed_image = cv2.resize(
+                image, self.params.target_size, interpolation=cv2.INTER_AREA
+            )
+
+            # Color constancy (Gray World assumption); channel order is RGB
+            if self.params.gray_world:
+                preprocessed_float = preprocessed_image.astype(np.float32)
+                avg_r = np.mean(preprocessed_float[:, :, 0])
+                avg_g = np.mean(preprocessed_float[:, :, 1])
+                avg_b = np.mean(preprocessed_float[:, :, 2])
+
+                avg_gray = (avg_r + avg_g + avg_b) / 3.0
+
+                if avg_r > 0:
+                    preprocessed_float[:, :, 0] *= avg_gray / avg_r
+                if avg_g > 0:
+                    preprocessed_float[:, :, 1] *= avg_gray / avg_g
+                if avg_b > 0:
+                    preprocessed_float[:, :, 2] *= avg_gray / avg_b
+
+                preprocessed_image = np.clip(preprocessed_float, 0, 255).astype(np.uint8)
+
+            # Noise reduction
+            if self.params.bilateral.enabled:
+                preprocessed_image = NoiseReductionUtil.bilateral_filter(
+                    preprocessed_image,
+                    self.params.bilateral.diameter,
+                    self.params.bilateral.sigma_color,
+                    self.params.bilateral.sigma_space,
+                    self.params.bilateral.border_type,
+                )
+            if self.params.guided.enabled:
+                preprocessed_image = NoiseReductionUtil.guided_filter(
+                    preprocessed_image,
+                    self.params.guided.radius,
+                    self.params.guided.epsilon,
+                    self.params.guided.color_space,
+                )
+            if self.params.nlm.enabled:
+                preprocessed_image = NoiseReductionUtil.nlm_denoising(
+                    preprocessed_image,
+                    self.params.nlm.filtering_strength,
+                    self.params.nlm.color_space,
+                )
+
+            # Illumination correction
+            if self.params.clahe.enabled:
+                preprocessed_image = IlluminationCorrectionUtil.clahe_equalization(
+                    preprocessed_image,
+                    self.params.clahe.clip_limit,
+                    self.params.clahe.tile_grid_size,
+                    self.params.clahe.color_space,
+                )
+
+            # Background mask generation
+            background_mask = BackgroundMaskUtil.generate_background_mask(
+                preprocessed_image, None, (25, 25)
+            )
+            preprocessed_variants["background_mask"] = background_mask.copy()
+            preprocessed_variants["foreground_mask"] = cv2.bitwise_not(background_mask).copy()
+
+            # Set white background to preprocessed image
+            mask = background_mask.astype(bool)
+            preprocessed_image[mask] = [255, 255, 255]
+
+            # Normalization
+            if self.params.normalize:
+                preprocessed_image = preprocessed_image.astype(np.float32) / 255.0
+
+            preprocessed_variants["preprocessed_image"] = preprocessed_image.copy()
+            self.preprocessed[view_type] = preprocessed_variants.copy()
+
+    def detect_markers(
+        self,
+        view_types: Optional[List[ViewType]] = None,
+        methods: Optional[List[DetectionMethod]] = None,
+    ) -> None:
+        """
+        Detects markers in preprocessed images for given view types.
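+
+        Illustrative usage (loader is a ViewLoader whose views have already been
+        preprocessed; the view types and methods below are examples, not
+        required values):
+
+            loader.detect_markers(
+                view_types=[ViewType.FRONT, ViewType.LEFT],
+                methods=[DetectionMethod.TRADITIONAL],
+            )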
+ """ + if not self.preprocessed: + raise ValueError("View images are not preprocessed") + + if not view_types: + view_types = [view_type for view_type in ViewType if view_type.marker_cfg is not None] + + if not methods: + methods = [DetectionMethod.TRADITIONAL, DetectionMethod.FRST] + + logging.info("Detecting markers for view types: %s", [vt.name for vt in view_types]) + + detector = MarkerDetector() + for view_type in view_types: + markers, _ = detector.detect( + self.preprocessed[view_type]["preprocessed_image"].copy(), + self.preprocessed[view_type]["foreground_mask"].copy(), + view_type, + methods, + ) + markers_image = MarkerDetector.draw_circles( + self.preprocessed[view_type]["original_image"].copy(), markers + ) + + markers = markers.squeeze(0).copy() + markers = np.column_stack((markers, np.array([None] * markers.shape[0], dtype=object))) + self.detected[view_type][DetectionMethod.MARKER] = markers.copy() + self.preprocessed[view_type]["markers"] = markers.copy() + self.preprocessed[view_type]["markers_image"] = markers_image.copy() + + def label_markers( + self, + view_types: Optional[List[ViewType]] = None, + ) -> None: + """ + Labels markers in preprocessed images for given view types. + """ + if not self.preprocessed: + raise ValueError("View images are not preprocessed") + + if not view_types: + view_types = [view_type for view_type in ViewType if view_type.marker_cfg is not None] + + logging.info("Labeling markers for view types: %s", [vt.name for vt in view_types]) + + labeler = MarkerLabeler() + for view_type in view_types: + if "markers" not in self.preprocessed[view_type]: + logging.warning(f"No markers detected for {view_type.name}, skipping labeling.") + continue + + labeled_markers = labeler.label_markers( + self.preprocessed[view_type]["original_image"].copy(), + self.preprocessed[view_type]["markers"].copy(), + view_type, + visualize=False, + ) + self.detected[view_type][DetectionMethod.MARKER] = labeled_markers.copy() + self.preprocessed[view_type]["labeled_markers"] = labeled_markers.copy() + + def detect_electrodes_basic( + self, + view_types: Optional[List[ViewType]] = None, + methods: Optional[List[DetectionMethod]] = None, + ) -> None: + """ + Detects electrodes in preprocessed images for given view types. 
+ """ + if not self.preprocessed: + raise ValueError("View images are not preprocessed") + + if not view_types: + view_types = [view_type for view_type in ViewType if view_type.marker_cfg is not None] + + if not methods: + methods = [DetectionMethod.TRADITIONAL, DetectionMethod.FRST] + + logging.info( + "Detecting electrodes (basic) for view types: %s", [vt.name for vt in view_types] + ) + + detector = BasicElectrodeDetector() + for view_type in view_types: + electrodes, _ = detector.detect( + self.preprocessed[view_type]["preprocessed_image"].copy(), + self.preprocessed[view_type]["foreground_mask"].copy(), + view_type, + methods, + ) + electrodes_image = BasicElectrodeDetector.draw_circles( + self.preprocessed[view_type]["original_image"].copy(), electrodes + ) + + electrodes = electrodes.squeeze(0).copy() + electrodes = np.column_stack( + (electrodes, np.array([None] * electrodes.shape[0], dtype=object)) + ) + self.detected[view_type][DetectionMethod.ELECTRODE_BASIC] = electrodes.copy() + self.preprocessed[view_type]["electrodes_basic"] = electrodes.copy() + self.preprocessed[view_type]["electrodes_basic_image"] = electrodes_image.copy() + + def detect_electrodes( + self, + view_types: Optional[List[ViewType]] = None, + ) -> None: + """ + Detects electrodes in preprocessed images for given view types. + """ + if not self.preprocessed: + raise ValueError("View images are not preprocessed") + + if not view_types: + view_types = [view_type for view_type in ViewType if view_type.marker_cfg is not None] + + logging.info("Detecting electrodes for view types: %s", [vt.name for vt in view_types]) + + detector = ElectrodeDetector() + for view_type in view_types: + electrodes, _, _ = detector.detect( + self.preprocessed[view_type]["preprocessed_image"].copy(), + self.preprocessed[view_type]["foreground_mask"].copy(), + view_type, + ) + electrodes_image = BasicElectrodeDetector.draw_circles( + self.preprocessed[view_type]["original_image"].copy(), electrodes + ) + + electrodes = electrodes.squeeze(0).copy() + electrodes = np.column_stack( + (electrodes, np.array([None] * electrodes.shape[0], dtype=object)) + ) + self.detected[view_type][DetectionMethod.ELECTRODE] = electrodes.copy() + self.preprocessed[view_type]["electrodes"] = electrodes.copy() + self.preprocessed[view_type]["electrodes_image"] = electrodes_image.copy() + + @staticmethod + def detect_markers_per_view( + view_type: ViewType, + methods: List[DetectionMethod], + preprocessed_entry: Dict[str, np.ndarray], + ): + """ + Run marker detection for one view type. + """ + detector = MarkerDetector() + markers, _ = detector.detect( + preprocessed_entry["preprocessed_image"].copy(), + preprocessed_entry["foreground_mask"].copy(), + view_type, + methods, + ) + markers_image = MarkerDetector.draw_circles( + preprocessed_entry["original_image"].copy(), markers + ) + markers = markers.squeeze(0).copy() + markers = np.column_stack((markers, np.array([None] * markers.shape[0], dtype=object))) + return ("marker", view_type, markers, markers_image) + + @staticmethod + def detect_basic_electrodes_per_view( + view_type: ViewType, + methods: List[DetectionMethod], + preprocessed_entry: Dict[str, np.ndarray], + ): + """ + Run basic electrode detection for one view type. 
+ """ + detector = BasicElectrodeDetector() + electrodes, _ = detector.detect( + preprocessed_entry["preprocessed_image"].copy(), + preprocessed_entry["foreground_mask"].copy(), + view_type, + methods, + ) + electrodes_image = BasicElectrodeDetector.draw_circles( + preprocessed_entry["original_image"].copy(), electrodes + ) + electrodes = electrodes.squeeze(0).copy() + electrodes = np.column_stack( + (electrodes, np.array([None] * electrodes.shape[0], dtype=object)) + ) + return ("electrode_basic", view_type, electrodes, electrodes_image) + + @staticmethod + def detect_advanced_electrodes_per_view( + view_type: ViewType, + preprocessed_entry: Dict[str, np.ndarray], + ): + """ + Run advanced electrode detection for one view type. + """ + detector = ElectrodeDetector() + electrodes, _, _ = detector.detect( + preprocessed_entry["preprocessed_image"].copy(), + preprocessed_entry["foreground_mask"].copy(), + view_type, + ) + electrodes_image = BasicElectrodeDetector.draw_circles( + preprocessed_entry["original_image"].copy(), electrodes + ) + electrodes = electrodes.squeeze(0).copy() + electrodes = np.column_stack( + (electrodes, np.array([None] * electrodes.shape[0], dtype=object)) + ) + return ("electrode_advanced", view_type, electrodes, electrodes_image) + + def detect_markers_and_electrodes( + self, + view_types: Optional[List[ViewType]] = None, + methods: Optional[List[DetectionMethod]] = None, + ) -> None: + """ + Detects markers and electrodes in preprocessed images for given view types. + """ + if not self.preprocessed: + raise ValueError("View images are not preprocessed") + + if not view_types: + view_types = [view_type for view_type in ViewType if view_type.marker_cfg is not None] + + if not methods: + methods = [DetectionMethod.TRADITIONAL, DetectionMethod.FRST] + + logging.info( + "Detecting markers and electrodes for view types: %s", [vt.name for vt in view_types] + ) + + with ProcessPoolExecutor() as executor: + futures = {} + + for view_type in view_types: + preprocessed_entry = self.preprocessed[view_type] + + # Submit all three detectors as separate processes per view type + futures[ + executor.submit( + ViewLoader.detect_markers_per_view, + view_type, + methods, + preprocessed_entry, + ) + ] = (view_type, "marker") + + futures[ + executor.submit( + ViewLoader.detect_basic_electrodes_per_view, + view_type, + methods, + preprocessed_entry, + ) + ] = (view_type, "electrode_basic") + + futures[ + executor.submit( + ViewLoader.detect_advanced_electrodes_per_view, + view_type, + preprocessed_entry, + ) + ] = (view_type, "electrode_advanced") + + # Collect results + for future in as_completed(futures): + result_type, view_type, data, image = future.result() + + if result_type == "marker": + self.detected[view_type][DetectionMethod.MARKER] = data.copy() + self.preprocessed[view_type]["markers"] = data.copy() + self.preprocessed[view_type]["markers_image"] = image.copy() + + elif result_type == "electrode_basic": + self.detected[view_type][DetectionMethod.ELECTRODE_BASIC] = data.copy() + self.preprocessed[view_type]["electrodes_basic"] = data.copy() + self.preprocessed[view_type]["electrodes_basic_image"] = image.copy() + + elif result_type == "electrode_advanced": + self.detected[view_type][DetectionMethod.ELECTRODE] = data.copy() + self.preprocessed[view_type]["electrodes"] = data.copy() + self.preprocessed[view_type]["electrodes_image"] = image.copy() + + def visualize( + self, + view_types: Optional[List[ViewType]] = None, + show_markers: bool = True, + show_electrodes: bool = 
True, + ) -> None: + """ + Visualize original, preprocessed, markers, and electrodes images for each selected view type. Each view type is shown in a new row. + """ + if not self.preprocessed: + raise ValueError("No preprocessed images available.") + + if not view_types: + view_types = list(self.preprocessed.keys()) + + # Build column setup + columns = [ + ("original_image", "Original"), + ("preprocessed_image", "Preprocessed"), + ] + if show_markers: + columns.append(("markers_image", "Markers")) + if show_electrodes: + columns.append(("electrodes_basic_image", "Electrodes (basic)")) + columns.append(("electrodes_image", "Electrodes")) + + n_rows = len(view_types) + n_cols = len(columns) + + fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows)) + + # Ensure axes is 2D for consistent indexing + if n_rows == 1: + axes = [axes] + else: + axes = axes.tolist() + + for row_idx, view_type in enumerate(view_types): + for col_idx, (key, title) in enumerate(columns): + ax = axes[row_idx][col_idx] if n_rows > 1 else axes[0][col_idx] + img = self.preprocessed[view_type].get(key) + + if img is not None: + ax.imshow(img, cmap="gray" if img.ndim == 2 else None) + ax.set_title(f"{view_type.name} - {title}") + else: + ax.set_title(f"{view_type.name} - {title} (missing)") + ax.axis("off") + + plt.tight_layout() + plt.show() diff --git a/src/processing_models/electrode/view_type.py b/src/processing_models/electrode/view_type.py new file mode 100644 index 0000000..4f2a319 --- /dev/null +++ b/src/processing_models/electrode/view_type.py @@ -0,0 +1,286 @@ +from enum import Enum +from dataclasses import dataclass + + +@dataclass +class ElectrodeConfig: + num_electrodes: int + min_area: int + min_distance: int + min_radius: int + max_radius: int + n_segments: int + + +@dataclass +class MarkerConfig: + num_green_markers: int + min_area: int + min_distance: int + min_radius: int + max_radius: int + + +class ViewType(Enum): + """ + Possible 3D head scan view angles, with metadata about electrodes and markers. 
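+
+    Each member value is a (label, ElectrodeConfig, MarkerConfig) tuple; views
+    that are not used for detection (e.g. BOTTOM) carry None configs, so callers
+    should guard against that. Illustrative access:
+
+        cfg = ViewType.FRONT.marker_cfg
+        if cfg is not None:
+            print(cfg.num_green_markers, cfg.min_radius, cfg.max_radius)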
+ """ + + FRONT = ( + "front", + ElectrodeConfig( + num_electrodes=None, + min_area=500, + min_distance=50, + min_radius=15, + max_radius=25, + n_segments=750, + ), + MarkerConfig( + num_green_markers=3, + min_area=1500, + min_distance=150, + min_radius=20, + max_radius=35, + ), # 2 - 1 + ) + BACK = ( + "back", + ElectrodeConfig( + num_electrodes=None, + min_area=500, + min_distance=50, + min_radius=10, + max_radius=25, + n_segments=750, + ), + MarkerConfig( + num_green_markers=0, + min_area=1500, + min_distance=75, + min_radius=20, + max_radius=35, + ), # 0 + ) + TOP = ( + "top", + ElectrodeConfig( + num_electrodes=None, + min_area=500, + min_distance=50, + min_radius=15, + max_radius=25, + n_segments=750, + ), + MarkerConfig( + num_green_markers=0, + min_area=1500, + min_distance=150, + min_radius=20, + max_radius=35, + ), # 0 + ) + RIGHT = ( + "right", + ElectrodeConfig( + num_electrodes=None, + min_area=500, + min_distance=75, + min_radius=15, + max_radius=25, + n_segments=750, + ), + MarkerConfig( + num_green_markers=8, + min_area=1500, + min_distance=100, + min_radius=20, + max_radius=35, + ), # 3 - 2 - 3 + ) + LEFT = ( + "left", + ElectrodeConfig( + num_electrodes=None, + min_area=500, + min_distance=75, + min_radius=15, + max_radius=25, + n_segments=750, + ), + MarkerConfig( + num_green_markers=8, + min_area=1500, + min_distance=100, + min_radius=20, + max_radius=35, + ), # 3 - 2 - 3 + ) + FRONT_TOP = ( + "front_top", + ElectrodeConfig( + num_electrodes=None, + min_area=500, + min_distance=50, + min_radius=10, + max_radius=25, + n_segments=750, + ), + MarkerConfig( + num_green_markers=7, + min_area=1500, + min_distance=150, + min_radius=15, + max_radius=35, + ), # 2 - 3 - (2) + ) + BACK_TOP = ( + "back_top", + ElectrodeConfig( + num_electrodes=None, + min_area=500, + min_distance=50, + min_radius=10, + max_radius=25, + n_segments=750, + ), + MarkerConfig( + num_green_markers=7, + min_area=1500, + min_distance=75, + min_radius=20, + max_radius=35, + ), # 3 - 1 - 2 - 1 + ) + FRONT_RIGHT = ( + "front_right", + ElectrodeConfig( + num_electrodes=None, + min_area=500, + min_distance=75, + min_radius=15, + max_radius=25, + n_segments=750, + ), + MarkerConfig( + num_green_markers=0, + min_area=1500, + min_distance=100, + min_radius=20, + max_radius=35, + ), # 0 + ) + FRONT_LEFT = ( + "front_left", + ElectrodeConfig( + num_electrodes=None, + min_area=500, + min_distance=75, + min_radius=15, + max_radius=25, + n_segments=750, + ), + MarkerConfig( + num_green_markers=0, + min_area=1500, + min_distance=100, + min_radius=20, + max_radius=35, + ), # 0 + ) + BACK_RIGHT = ( + "back_right", + ElectrodeConfig( + num_electrodes=None, + min_area=500, + min_distance=75, + min_radius=15, + max_radius=25, + n_segments=750, + ), + MarkerConfig( + num_green_markers=0, + min_area=1500, + min_distance=100, + min_radius=20, + max_radius=35, + ), # 0 + ) + BACK_LEFT = ( + "back_left", + ElectrodeConfig( + num_electrodes=None, + min_area=500, + min_distance=75, + min_radius=15, + max_radius=25, + n_segments=750, + ), + MarkerConfig( + num_green_markers=0, + min_area=1500, + min_distance=100, + min_radius=20, + max_radius=35, + ), # 0 + ) + TOP_RIGHT = ( + "top_right", + ElectrodeConfig( + num_electrodes=None, + min_area=500, + min_distance=75, + min_radius=15, + max_radius=25, + n_segments=750, + ), + MarkerConfig( + num_green_markers=0, + min_area=1500, + min_distance=100, + min_radius=20, + max_radius=35, + ), # 0 + ) + TOP_LEFT = ( + "top_left", + ElectrodeConfig( + num_electrodes=None, + min_area=500, + 
min_distance=75, + min_radius=15, + max_radius=25, + n_segments=750, + ), + MarkerConfig( + num_green_markers=0, + min_area=1500, + min_distance=100, + min_radius=20, + max_radius=35, + ), # 0 + ) + + # Not needed views + BOTTOM = ("bottom", None, None) + FRONT_BOTTOM = ("front_bottom", None, None) + BACK_BOTTOM = ("back_bottom", None, None) + BOTTOM_RIGHT = ("bottom_right", None, None) + BOTTOM_LEFT = ("bottom_left", None, None) + + def __init__(self, label: str, electrode_cfg: ElectrodeConfig, marker_cfg: MarkerConfig): + self.label = label + self.electrode_cfg = electrode_cfg + self.marker_cfg = marker_cfg + + def __str__(self) -> str: + return ( + f"{self.label} " + f"(electrodes={self.electrode_cfg.num_electrodes if self.electrode_cfg else 'NA'}, " + f"min_area_e={self.electrode_cfg.min_area if self.electrode_cfg else 'NA'}, " + f"min_dist_e={self.electrode_cfg.min_distance if self.electrode_cfg else 'NA'}, " + f"markers={self.marker_cfg.num_green_markers if self.marker_cfg else 'NA'}, " + f"min_area_m={self.marker_cfg.min_area if self.marker_cfg else 'NA'}, " + f"min_dist_m={self.marker_cfg.min_distance if self.marker_cfg else 'NA'}, " + f"radius=[{self.marker_cfg.min_radius}, {self.marker_cfg.max_radius}]" + ) diff --git a/src/processing_models/mesh/__init__.py b/src/processing_models/mesh/__init__.py new file mode 100644 index 0000000..75696f6 --- /dev/null +++ b/src/processing_models/mesh/__init__.py @@ -0,0 +1,17 @@ +from .cap_extractor import CapExtractor +from .electrode_curvature_detector import ElectrodeCurvatureDetector +from .fiducial_labeler import FiducialLabeler +from .head_capturer import HeadCapturer +from .head_cleaner import HeadCleaner +from .head_pose_aligner import HeadPoseAligner +from .mesh_loader import MeshLoader + +__all__ = [ + "CapExtractor", + "ElectrodeCurvatureDetector", + "FiducialLabeler", + "HeadCapturer", + "HeadCleaner", + "HeadPoseAligner", + "MeshLoader", +] diff --git a/src/processing_models/mesh/cap_extractor.py b/src/processing_models/mesh/cap_extractor.py new file mode 100644 index 0000000..7709a40 --- /dev/null +++ b/src/processing_models/mesh/cap_extractor.py @@ -0,0 +1,126 @@ +import logging +import numpy as np +import pandas as pd + +from typing import List, Optional, Tuple +from vedo import Mesh, Plane, Plotter, Sphere, settings + +from .head_cleaner import HeadCleaner + + +settings.default_backend = "vtk" +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + +class CapExtractor: + def __init__( + self, + mesh: Mesh, + fiducials: pd.DataFrame, + texture: Optional[np.ndarray] = None, + ) -> None: + """ + Initialize the cap extractor to isolate the cap portion of a head mesh above the plane. + """ + self.mesh_raw = mesh.clone() + self.mesh = mesh.clone() + self.mesh_cap = None + self.texture = texture.copy() if texture is not None else None + self.fiducials = fiducials.copy() + self.plane = None + + if texture is None: + self.mesh.texture(None).color("gray") + + def _get_fiducials(self, labels: List[str]) -> List[np.ndarray]: + """ + Extract and validate fiducial points. + """ + missing = [lbl for lbl in labels if lbl not in self.fiducials.index] + if missing: + raise ValueError(f"Missing fiducials: {', '.join(missing)}") + + return [self.fiducials.loc[label].to_numpy(dtype=float).copy() for label in labels] + + def _compute_cutting_plane(self) -> Tuple[np.ndarray, np.ndarray]: + """ + Compute the cutting plane defined by NAS-LPA-RPA fiducials. 
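+
+        The plane passes through the centroid of the NAS, LPA and RPA points;
+        its normal is the last right-singular vector of the centered point
+        matrix (the direction of least variance), flipped if necessary so that
+        its y-component is positive, i.e. pointing towards the top of the head.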
+ """ + nas, lpa, rpa = self._get_fiducials(["NAS", "LPA", "RPA"]) + points = np.array([nas, lpa, rpa]) + + center = points.mean(axis=0) + + # Plane normal via SVD (robust to collinearity) + _, _, vh = np.linalg.svd(points - center) + normal = vh[-1] / np.linalg.norm(vh[-1]) + + # Flip normal to point upward + if normal[1] < 0: + normal = -normal + + logging.info(f"Cutting plane: center={center}, normal={normal}") + return center, normal + + def extract_cap(self, margin: float = 0.0) -> Mesh: + """ + Extract the cap portion of the mesh above the cutting plane. + """ + center, normal = self._compute_cutting_plane() + self.plane = {"center": center, "normal": normal} + + adjusted_center = center - margin * normal + + # Cut everything above plane + self.mesh = self.mesh.cut_with_plane(origin=adjusted_center, normal=normal) + + # Keep largest connected component only + self.cleaner = HeadCleaner(self.mesh, self.fiducials, self.texture) + self.cleaner.clean_from_unwanted_objects() + + self.mesh_cap = self.mesh.clone() + return self.mesh_cap + + def _add_plane_visualization(self, plotter: Plotter) -> None: + """ + Add a plane visualization to the plotter. + """ + bounds = self.mesh.bounds() + x_size_temp, z_size_temp = bounds[1] - bounds[0], bounds[5] - bounds[4] + center, normal = self.plane["center"], self.plane["normal"] + + # Adjust plane size using fiducials if available + if {"NAS", "INI", "LPA", "RPA"}.issubset(self.fiducials.index): + nas, ini, lpa, rpa = self._get_fiducials(["NAS", "INI", "LPA", "RPA"]) + x_size_temp = abs(lpa[0] - rpa[0]) + z_size_temp = abs(nas[2] - ini[2]) + center = np.mean([nas, lpa, rpa], axis=0) + + # Make plane roughly square + x_size = x_size_temp * 0.75 + z_size_temp + z_size = z_size_temp * 0.75 + x_size_temp + + plane = Plane(pos=center, normal=normal, s=(z_size, x_size), alpha=0.3).c("red") + plotter.add(plane) + + def visualize_extraction( + self, + show_fiducials: bool = True, + show_plane: bool = True, + ) -> None: + """ + Visualize the extraction process. + """ + plotter = Plotter(title="EEG Head Mesh Processing Result") + plotter.add(self.mesh) + + if show_fiducials: + fiducial_spheres = [ + Sphere(pos=coords, r=2.5, c="red") for coords in self.fiducials.values + ] + plotter.add(fiducial_spheres) + + if show_plane and self.plane is not None: + self._add_plane_visualization(plotter) + + plotter.show() diff --git a/src/processing_models/mesh/electrode_curvature_detector.py b/src/processing_models/mesh/electrode_curvature_detector.py new file mode 100644 index 0000000..e36930a --- /dev/null +++ b/src/processing_models/mesh/electrode_curvature_detector.py @@ -0,0 +1,255 @@ +import logging +import numpy as np + +from vedo import Mesh, settings +from scipy.spatial import KDTree +from typing import Dict, List, Tuple + + +settings.default_backend = "vtk" +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + +class ElectrodeCurvatureDetector: + def __init__(self, mesh: Mesh) -> None: + """ + Detect electrode locations using curvature analysis and radial symmetry. 
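+
+        Illustrative usage (mesh is a vedo Mesh of the extracted cap region; the
+        keys of the returned dict follow extract_curvatures below):
+
+            detector = ElectrodeCurvatureDetector(mesh)
+            results = detector.extract_curvatures()
+            candidates = results["candidates"]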
+ """ + self.mesh_raw = mesh.clone() + self.mesh = mesh.clone() + self.vertices = self.mesh.points().copy() + + # Compute curvature (0 = Gaussian, 1 = Mean) + self.mesh.compute_curvature(method=1) + self.curvatures = self.mesh.pointdata["Mean_Curvature"] + + self.candidates = None + self.filtered_curvatures = None + self.saliency_map = None + self.probability_map = None + self.gradient_directions = None + + def _remove_extreme_curvatures( + self, + outlier_percentile_low: int = 5, + outlier_percentile_high: int = 95, + smoothing_iterations: int = 3, + ) -> None: + """ + Clip extreme curvature values and apply spatial smoothing. + """ + curvatures = self.curvatures.copy() + + # Remove extreme outlier and cap values + low_threshold = np.percentile(curvatures, outlier_percentile_low) + high_threshold = np.percentile(curvatures, outlier_percentile_high) + filtered_curvatures = np.clip(curvatures, low_threshold, high_threshold) + + # Apply spatial smoothing to reduce noise + if smoothing_iterations > 0: + tree = KDTree(self.vertices) + for _ in range(smoothing_iterations): + smoothed = filtered_curvatures.copy() + for i, vertex in enumerate(self.vertices): + neighbor_indices = tree.query_ball_point(vertex, r=3.0) + if len(neighbor_indices) > 1: + smoothed[i] = np.median(filtered_curvatures[neighbor_indices]) + filtered_curvatures = smoothed + + self.filtered_curvatures = filtered_curvatures + + def _compute_local_curvature_std(self, radius: float = 8.0) -> np.ndarray: + """ + Compute local curvature standard deviation for each vertex. + """ + tree = KDTree(self.vertices) + local_stds = np.zeros(len(self.vertices)) + + for i, vertex in enumerate(self.vertices): + neighbor_indices = tree.query_ball_point(vertex, radius) + if len(neighbor_indices) > 3: + neighbor_curvatures = self.curvatures[neighbor_indices] + local_stds[i] = np.std(neighbor_curvatures) + + return local_stds + + def _create_head_region_mask(self) -> np.ndarray: + """ + Create mask focusing on head region while excluding artifacts and high-variance areas. + """ + curvatures = ( + self.filtered_curvatures if self.filtered_curvatures is not None else self.curvatures + ) + + # Start with all vertices + mask = np.ones(len(self.vertices), dtype=bool) + + # Exclude remaining extreme curvature outliers + curvature_threshold = np.percentile(np.abs(curvatures), 85) + mask &= np.abs(curvatures) < curvature_threshold + + # Focus on relatively stable regions + local_curvature_std = self._compute_local_curvature_std(radius=8.0) + curvature_std_threshold = np.percentile(local_curvature_std, 90) + mask &= local_curvature_std < curvature_std_threshold + + logging.info( + "Head region mask includes %d / %d vertices (%.1f%%)", + np.sum(mask), + len(mask), + np.sum(mask) / len(mask) * 100, + ) + return mask + + def _compute_saliency_map( + self, + region_mask: np.ndarray, + neighborhood_radius: float = 7.0, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Compute saliency map, probability map, and gradient directions. 
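+
+        For each vertex i the local prominence is
+        (curvature_i - mean(neighborhood)) / std(neighborhood), clamped at zero,
+        and the normalized saliency s is mapped to a probability with a sigmoid,
+        p = 1 / (1 + exp(-5 * (s - 0.3))). Vertices outside the head region mask
+        are zeroed in all three outputs.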
+        """
+        vertices = self.vertices.copy()
+        curvatures = (
+            self.filtered_curvatures if self.filtered_curvatures is not None else self.curvatures
+        )
+
+        tree = KDTree(vertices)
+        n_vertices = len(vertices)
+
+        # Initialize maps
+        saliency_scores = np.zeros(n_vertices)
+        probability_map = np.zeros(n_vertices)
+        gradient_directions = np.zeros((n_vertices, 3))  # 3D gradient directions
+
+        # Only process vertices in the head region
+        valid_indices = np.where(region_mask)[0]
+
+        for i in valid_indices:
+            vertex = vertices[i]
+            current_curvature = curvatures[i]
+
+            # Find local neighborhood
+            neighbor_indices = tree.query_ball_point(vertex, neighborhood_radius)
+            neighbor_indices = [idx for idx in neighbor_indices if region_mask[idx]]
+
+            if len(neighbor_indices) < 5:
+                continue
+
+            neighbor_curvatures = curvatures[neighbor_indices]
+            neighbor_positions = vertices[neighbor_indices]
+
+            # Compute local saliency based on:
+            # 1. Local curvature prominence
+            local_mean = np.mean(neighbor_curvatures)
+            local_std = np.std(neighbor_curvatures)
+
+            if local_std > 0:
+                prominence = (current_curvature - local_mean) / local_std
+                saliency_scores[i] = max(0, prominence)  # Only positive prominences
+
+            # 2. Compute gradient direction to local maximum
+            if len(neighbor_indices) > 1:
+                # Find direction to highest curvature neighbor
+                max_neighbor_idx = neighbor_indices[np.argmax(neighbor_curvatures)]
+                if max_neighbor_idx != i:
+                    direction = vertices[max_neighbor_idx] - vertex
+                    norm = np.linalg.norm(direction)
+                    if norm > 0:
+                        gradient_directions[i] = direction / norm
+
+            # 3. Enhance based on local geometry (bump-like structures)
+            if len(neighbor_indices) >= 8:
+                # Check if the current point is a local maximum
+                is_local_max = current_curvature >= np.percentile(neighbor_curvatures, 75)
+
+                # Compute radial symmetry score
+                radial_positions = neighbor_positions - vertex
+                distances = np.linalg.norm(radial_positions, axis=1)
+                if len(distances) > 0 and np.std(distances) > 0:
+                    # Prefer circular/symmetric arrangements
+                    symmetry_score = 1.0 / (1.0 + np.std(distances) / np.mean(distances))
+                    if is_local_max:
+                        saliency_scores[i] *= 1.0 + symmetry_score
+
+        # Normalize saliency scores
+        if np.max(saliency_scores) > 0:
+            saliency_scores /= np.max(saliency_scores)
+
+        # Convert saliency to probability using a sigmoid-like function
+        probability_map = 1.0 / (1.0 + np.exp(-5 * (saliency_scores - 0.3)))
+        probability_map[~region_mask] = 0
+        saliency_scores[~region_mask] = 0
+        gradient_directions[~region_mask] = 0
+
+        self.saliency_map = saliency_scores
+        self.probability_map = probability_map
+        self.gradient_directions = gradient_directions
+
+        return saliency_scores, probability_map, gradient_directions
+
+    def _detect_electrode_candidates(
+        self,
+        min_probability: float = 0.3,
+        min_distance: float = 15.0,
+    ) -> List[Dict[str, np.ndarray]]:
+        """
+        Detect electrode candidates using non-maximum suppression on probability map.
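+
+        Candidates are taken greedily in order of decreasing probability, and a
+        candidate is kept only if it lies at least min_distance away from every
+        previously accepted candidate (a simple non-maximum suppression).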
+ """ + if self.probability_map is None: + raise ValueError("Must compute saliency map first") + + # Find peaks in probability map + high_prob_indices = np.where(self.probability_map > min_probability)[0] + if len(high_prob_indices) == 0: + return [] + + # Sort by probabiilty + sorted_indices = high_prob_indices[ + np.argsort(self.probability_map[high_prob_indices])[::-1] + ] + + # Apply non-maximum suppression based on distance + candidates = [] + for vertex_idx in sorted_indices: + vertex_pos = self.vertices[vertex_idx] + probability = self.probability_map[vertex_idx] + direction = self.gradient_directions[vertex_idx] + + # Check if too close to existing candidates + too_close = any( + np.linalg.norm(vertex_pos - c["position"]) < min_distance for c in candidates + ) + if not too_close: + candidates.append( + { + "vertex_index": vertex_idx, + "position": vertex_pos, + "probability": probability, + "gradient_direction": direction, + "saliency": self.saliency_map[vertex_idx], + } + ) + + logging.info("Detected %d electrode candidates.", len(candidates)) + self.candidates = candidates + return candidates + + def extract_curvatures(self): + """ + Run the full electrode detection pipeline. + """ + self._remove_extreme_curvatures() + head_mask = self._create_head_region_mask() + saliency, probability, directions = self._compute_saliency_map( + region_mask=head_mask, neighborhood_radius=7 + ) + candidates = self._detect_electrode_candidates() + + return { + "saliency_map": saliency, + "probability_map": probability, + "gradient_directions": directions, + "candidates": candidates, + "head_mask": head_mask, + } diff --git a/src/processing_models/mesh/fiducial_labeler.py b/src/processing_models/mesh/fiducial_labeler.py new file mode 100644 index 0000000..d0d0234 --- /dev/null +++ b/src/processing_models/mesh/fiducial_labeler.py @@ -0,0 +1,121 @@ +import logging + +from dataclasses import dataclass +from typing import List, Optional, Tuple +from vedo import Plotter, Sphere, settings + +from .mesh_loader import MeshLoader + + +settings.default_backend = "vtk" +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + +@dataclass +class FiducialInfo: + code: str + name: str + color: str + + +class FiducialLabeler: + def __init__(self, mesh_path: str, texture_path: str) -> None: + """ + Interactive tool for labeling anatomical fiducials on a 3D head mesh. + """ + + # Define standard fiducial labels + self.fiducial_info: List[FiducialInfo] = [ + FiducialInfo("NAS", "Nasion", "red"), + FiducialInfo("LPA", "Left Pre-Auricular", "blue"), + FiducialInfo("RPA", "Right Pre-Auricular", "green"), + FiducialInfo("LHJ", "Left Helix-Tragus Junction", "cyan"), + FiducialInfo("RHJ", "Right Helix-Tragus Junction", "magenta"), + FiducialInfo("INI", "Inion", "yellow"), + FiducialInfo("VTX", "Vertex (Cz approx.)", "orange"), + ] + self._current_index: int = 0 + + # Initialize mesh loader + self.mesh_loader = MeshLoader(mesh_path, texture_path) + self.mesh = self.mesh_loader.mesh + + # Initialize fiducials storage (None until picked) + self.fiducials: List[Optional[Tuple[str, float, float, float]]] = [ + None for _ in self.fiducial_info + ] + + # Initialize plotter + self.plotter = Plotter(title="Fiducial Labeling Tool", size=(1200, 800)) + + def _log_fiducial(self) -> None: + """ + Log which fiducial the user should pick next. 
+ """ + fiducial = self.fiducial_info[self._current_index] + logging.info(f"Pick: {fiducial.name} ({fiducial.code})") + + def _on_key_press(self, event) -> None: + """ + Handle keyboard input events. + - Press 'q' to quit and print fiducials. + """ + if event.keypress == "q" or event.keypress == "s": + logging.info(f"Final fiducials:\n{self.fiducials}") + logging.info("Exiting Fiducial Labeling Tool...") + self.plotter.close() + + def _on_middle_click(self, event) -> None: + """ + Handle middle mouse button clicks to place fiducials. + """ + fiducial = self.fiducial_info[self._current_index] + pos = tuple(map(float, event.picked3d)) + + logging.info(f"Picked {fiducial.name} ({fiducial.code}) at {pos}") + + # Store fiducial + self.fiducials[self._current_index] = (fiducial.code, *pos) + + # Move to next fiducial and log instruction + self._current_index = (self._current_index + 1) % len(self.fiducial_info) + self._log_fiducial() + + # Place a marker sphere + self.plotter.add(Sphere(pos=pos, r=2.5, c=fiducial.color)) + + def run(self) -> None: + """ + Start the interactive labeling tool. + """ + logging.info("Starting Fiducial Labeling Tool...") + if self.mesh: + self.plotter.add(self.mesh) + + # Register event callbacks + self.plotter.add_callback("KeyPress", self._on_key_press) + self.plotter.add_callback("MiddleButtonPress", self._on_middle_click) + + # Log instruction before first click + self._log_fiducial() + + # Show interactive window + self.plotter.show() + + def save(self, file_path: str) -> None: + """ + Save labeled fiducials to a CSV file. + Format: CODE,x,y,z + """ + if self.fiducials is None or any(fid is None for fid in self.fiducials): + logging.warning("Not all fiducials are labeled.") + return + + try: + with open(file_path, "w", encoding="utf-8") as f: + for fid in self.fiducials: + if fid is not None: + f.write(f"{fid[0]},{fid[1]},{fid[2]},{fid[3]}\n") + logging.info(f"Fiducials saved to {file_path}") + except OSError as e: + logging.error(f"Failed to save fiducials: {e}") diff --git a/src/processing_models/mesh/head_capturer.py b/src/processing_models/mesh/head_capturer.py new file mode 100644 index 0000000..55d378f --- /dev/null +++ b/src/processing_models/mesh/head_capturer.py @@ -0,0 +1,437 @@ +import os +import logging +import numpy as np +import pandas as pd + +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union +from vedo import Arrow, Axes, Mesh, Plane, Plotter, Sphere, Text3D, settings + + +settings.default_backend = "vtk" +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + +class HeadCapturer: + def __init__( + self, + mesh: Mesh, + fiducials: pd.DataFrame, + center: Optional[np.ndarray] = None, + texture: Optional[np.ndarray] = None, + ) -> None: + """ + Initialize to capture screenshots of a 3D head mesh from multiple camera views. + """ + self.mesh = mesh.clone() + self.fiducials = fiducials.copy() + self.texture = texture.copy() if texture is not None else None + + # Apply mesh appearance + self.mesh.alpha(1.0) + if self.texture is not None: + self.mesh.texture(self.texture) + + # Default center if none provided + self.center = center.copy() if center is not None else np.array([0, 0, 0]) + + @staticmethod + def _setup_camera(center: np.ndarray, camera_distance: Dict[str, float]) -> Dict[str, Any]: + """ + Configure standard orthogonal camera views: front, back, right, left, top, bottom. 
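+
+        Besides the six axis-aligned views, diagonal combinations such as
+        front_top and back_left are also configured. The returned dict holds
+        "center", "distance" and a "views" mapping in which every entry carries
+        a camera position ("pos"), an up vector ("up"), a short "name" and a
+        human-readable "description".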
+ """ + dx, dy, dz = ( + camera_distance["x"], + camera_distance["y"], + camera_distance["z"], + ) + cx, cy, cz = center + + return { + "center": center, + "distance": camera_distance, + "views": { + "front": { + "pos": [cx, cy, cz + dz], + "up": [0, 1, 0], + "name": "front", + "description": "Front view (+Z)", + }, + "back": { + "pos": [cx, cy, cz - dz], + "up": [0, 1, 0], + "name": "back", + "description": "Back view (-Z)", + }, + "top": { + "pos": [cx, cy + dy, cz], + "up": [0, 0, -1], + "name": "top", + "description": "Top view (+Y)", + }, + "bottom": { + "pos": [cx, cy - dy, cz], + "up": [0, 0, 1], + "name": "bottom", + "description": "Bottom view (-Y)", + }, + "right": { + "pos": [cx - dx, cy, cz], + "up": [0, 1, 0], + "name": "right", + "description": "Right view (+X)", + }, + "left": { + "pos": [cx + dx, cy, cz], + "up": [0, 1, 0], + "name": "left", + "description": "Left view (-X)", + }, + "front_top": { + "pos": [cx, (cy + dy) * 0.75, (cz + dz) * 0.75], + "up": [0, 1, 0], + "name": "front_top", + "description": "Front-top view (+Z, +Y)", + }, + "front_bottom": { + "pos": [cx, (cy - dy) * 0.75, (cz + dz) * 0.75], + "up": [0, 1, 0], + "name": "front_bottom", + "description": "Front-bottom view (+Z, -Y)", + }, + "back_top": { + "pos": [cx, (cy + dy) * 0.75, (cz - dz) * 0.75], + "up": [0, 1, 0], + "name": "back_top", + "description": "Back-top view (-Z, +Y)", + }, + "back_bottom": { + "pos": [cx, (cy - dy) * 0.75, (cz - dz) * 0.75], + "up": [0, 1, 0], + "name": "back_bottom", + "description": "Back-bottom view (-Z, -Y)", + }, + "front_right": { + "pos": [(cx - dx) * 0.75, cy, (cz + dz) * 0.75], + "up": [0, 1, 0], + "name": "front_right", + "description": "Front-right view (+X, +Z)", + }, + "front_left": { + "pos": [(cx + dx) * 0.75, cy, (cz + dz) * 0.75], + "up": [0, 1, 0], + "name": "front_left", + "description": "Front-left view (-X, +Z)", + }, + "back_right": { + "pos": [(cx - dx) * 0.75, cy, (cz - dz) * 0.75], + "up": [0, 1, 0], + "name": "back_right", + "description": "Back-right view (+X, -Z)", + }, + "back_left": { + "pos": [(cx + dx) * 0.75, cy, (cz - dz) * 0.75], + "up": [0, 1, 0], + "name": "back_left", + "description": "Back-left view (-X, -Z)", + }, + "top_right": { + "pos": [(cx - dx) * 0.75, (cy + dy) * 0.75, cz], + "up": [0, 1, 0], + "name": "top_right", + "description": "Top-right view (+X, +Y)", + }, + "top_left": { + "pos": [(cx + dx) * 0.75, (cy + dy) * 0.75, cz], + "up": [0, 1, 0], + "name": "top_left", + "description": "Top-left view (-X, +Y)", + }, + "bottom_right": { + "pos": [(cx - dx) * 0.75, (cy - dy) * 0.75, cz], + "up": [0, 1, 0], + "name": "top_right", + "description": "Bottom-right view (+X, -Y)", + }, + "bottom_left": { + "pos": [(cx + dx) * 0.75, (cy - dy) * 0.75, cz], + "up": [0, 1, 0], + "name": "top_left", + "description": "Bottom-left view (-X, -Y)", + }, + }, + } + + def capture_all_views( + self, + output_dir: Union[str, Path], + image_size: Tuple[int, int] = (1024, 1024), + show_cap: Optional[bool] = False, + cmap: np.ndarray = None, + ) -> None: + """ + Capture screenshots from all 6 predefined camera views. 
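+
+        Every view configured by _setup_camera is rendered, not only the six
+        axis-aligned ones. Illustrative usage (mesh, fiducials and the output
+        directory below are placeholders):
+
+            capturer = HeadCapturer(mesh, fiducials)
+            capturer.capture_all_views(output_dir="screenshots", image_size=(1024, 1024))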
+ """ + # Estimate a reasonable camera distance from mesh bounds + center = self.center.copy() + bounds = self.mesh.bounds() + camera_distance = { + "x": (bounds[1] - bounds[0]) * 2.5, + "y": (bounds[3] - bounds[2]) * 1.75, + "z": (bounds[5] - bounds[4]) * 2.5, + } + + # Define a custom camera position based on fiducials (for cap views) + if show_cap and self.fiducials is not None: + required_labels = ["NAS", "LPA", "RPA"] + + if set(required_labels).issubset(self.fiducials.index): + # Extract fiducial coordinates + fiducials = { + label: self.fiducials.loc[label].to_numpy(dtype=float) + for label in required_labels + } + + nas, lpa, rpa = fiducials["NAS"], fiducials["LPA"], fiducials["RPA"] + plane_center = np.mean([nas, lpa, rpa], axis=0) + + # Update center.y to position the camera above the cropped mesh + center_y = (bounds[3] - plane_center[2]) / 2.0 + center[1] = center_y + + # Set camera distance in y direction (scaled by bounding box extent) + bbox_y_extent = bounds[3] - bounds[2] + camera_distance["y"] = bbox_y_extent * 2.75 + + # Setup camera positions + camera_config = HeadCapturer._setup_camera(center, camera_distance) + + # Custom color map + if cmap is not None: + self.mesh.cmap("turbo", cmap) + + os.makedirs(output_dir, exist_ok=True) + screenshot_paths: List[str] = [] + + plotter = Plotter(offscreen=True, interactive=False, size=image_size) + + for view_name, view in camera_config["views"].items(): + logging.info(f"Capturing {view['description']}...") + + # Reset scene + plotter.clear() + plotter.add(self.mesh) + + # Camera setup + plotter.camera.SetPosition(view["pos"]) + plotter.camera.SetFocalPoint(camera_config["center"]) + plotter.camera.SetViewUp(view["up"]) + + plotter.background((255, 255, 255)) + plotter.render() + + # Save screenshot + screenshot_path = os.path.join(output_dir, f"{view_name}.png") + plotter.screenshot(screenshot_path) + screenshot_paths.append(screenshot_path) + + # plotter.close() + logging.info(f"All screenshots saved in: {output_dir}") + + @staticmethod + def capture_single_view( + mesh: Mesh, + name: str = None, + output_dir: Union[str, Path] = None, + center: np.ndarray = np.array([0.0, 0.0, 0.0]), + image_size: Tuple[int, int] = (1024, 1024), + mesh_alpha: Optional[float] = 1.0, + mesh_color: Optional[str] = None, + fiducials: Optional[pd.DataFrame] = None, + bounding_box: Optional[Dict[str, Any]] = None, + coordinate_vectors: Optional[Dict[str, np.ndarray]] = None, + cap_plane: Optional[Dict[str, Any]] = None, + cmap: Optional[np.ndarray] = None, + curvatures: Optional[List[np.ndarray]] = None, + show_axes: Optional[bool] = False, + electrodes: Optional[pd.DataFrame] = None, + ) -> str: + """ + Capture a screenshot from a custom single view. 
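+
+        If both name and output_dir are given, the screenshot is written to disk
+        and its path is returned; otherwise the rendered frame is returned as an
+        image array. Illustrative usage (all names below are placeholders):
+
+            shot = HeadCapturer.capture_single_view(
+                mesh,
+                name="alignment_check",
+                output_dir="screenshots",
+                fiducials=fiducials,
+                show_axes=True,
+            )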
+ """ + mesh_copy = mesh.clone() + bounds = mesh_copy.bounds() + + # Custom view + view = { + "pos": [ + center[0] - (bounds[1] - bounds[0]) * 1.25, + center[1] + (bounds[3] - bounds[2]) * 1.25, + center[2] + (bounds[5] - bounds[4]) * 3.5, + ], + "up": [0, 1, 0], + } + + # Custom color map + if cmap is not None: + mesh_copy.cmap("turbo", cmap) + + plotter = Plotter(offscreen=True, interactive=False, size=image_size) + + # Mesh appearance + plotter.add(mesh_copy) + mesh_copy.alpha(mesh_alpha) + if mesh_color: + mesh_copy.color(mesh_color) + + # Fiducials + if fiducials is not None: + spheres = [Sphere(pos=f, r=6, c="red") for f in fiducials.values] + plotter.add(spheres) + # labels = [ + # Text3D( + # l, + # pos=f + np.array([5, 5, 5]), + # s=8.5, + # depth=0.45, + # c="black", + # font="Theemim", + # ) + # for l, f in fiducials.iterrows() + # ] + # plotter.add(labels) + + # Electrodes + if electrodes is not None: + + def get_color(index): + if index == "MARKER": + return "green" + elif index == "ELECTRODE": + return "red" + elif index == "ELECTRODE_BASIC": + return "blue" + return "black" + + spheres = [Sphere(pos=f, r=4, c=get_color(idx)) for idx, f in electrodes.iterrows()] + plotter.add(spheres) + + # Bounding box planes + if bounding_box: + x_size, y_size, z_size = bounding_box["size"] + margin = 0.025 + faces = [ + ( + [0, 1, 0], + [0, y_size / 2, 0], + (z_size * (1 - margin), x_size * (1 - margin)), + ), + ( + [0, -1, 0], + [0, -y_size / 2, 0], + (z_size * (1 - margin), x_size * (1 - margin)), + ), + ( + [1, 0, 0], + [x_size / 2, 0, 0], + (z_size * (1 - margin), y_size * (1 - margin)), + ), + ( + [-1, 0, 0], + [-x_size / 2, 0, 0], + (z_size * (1 - margin), y_size * (1 - margin)), + ), + ( + [0, 0, 1], + [0, 0, z_size / 2], + (x_size * (1 - margin), y_size * (1 - margin)), + ), + ( + [0, 0, -1], + [0, 0, -z_size / 2], + (x_size * (1 - margin), y_size * (1 - margin)), + ), + ] + planes = [ + Plane(pos=center + np.array(offset), normal=normal, s=size) + .alpha(0.25) + .color("gray") + for normal, offset, size in faces + ] + plotter.add(planes) + + # Coordinate vectors + if coordinate_vectors: + plotter.add(Sphere(pos=center, r=10, c="black")) + plotter.add(Arrow(center, center + coordinate_vectors["x"] * 90.0, c="red", s=0.75)) + plotter.add(Arrow(center, center + coordinate_vectors["y"] * 75.0, c="green", s=0.75)) + plotter.add(Arrow(center, center + coordinate_vectors["z"] * 75.0, c="blue", s=0.75)) + + # Show cap plane + if cap_plane is not None and fiducials is not None: + default_x_size = bounds[1] - bounds[0] + default_z_size = bounds[5] - bounds[4] + center, normal = cap_plane["center"].copy(), cap_plane["normal"] + + plane_center = center.copy() # fallback if fiducials not available + x_size_temp, z_size_temp = default_x_size, default_z_size + + required_labels = ["NAS", "INI", "LPA", "RPA"] + if set(required_labels).issubset(fiducials.index): + # Extract fiducial coordinates safely + fids = { + label: fiducials.loc[label].to_numpy(dtype=float) for label in required_labels + } + + nas, ini, lpa, rpa = fids["NAS"], fids["INI"], fids["LPA"], fids["RPA"] + + # Plane dimensions based on fiducials + x_size_temp = abs(lpa[0] - rpa[0]) + z_size_temp = abs(nas[2] - ini[2]) + + # Plane center as midpoints + plane_center = np.mean([nas, lpa, rpa], axis=0) + + # Adjust center.y to account for cropped mesh height + center[1] = (bounds[3] - plane_center[2]) / 2.0 + + # Scale factors for making plane roughly square + square_scale = 0.75 + x_size = x_size_temp * square_scale + z_size_temp + 
z_size = z_size_temp * square_scale + x_size_temp + + # Build and add visualization plane + plane = Plane(pos=plane_center, normal=normal, s=(z_size, x_size), alpha=0.3).c("red") + plotter.add(plane) + + # Electrode candidates based on curvature + if curvatures: + spheres = [Sphere(pos=c, r=1, c="red") for c in curvatures] + plotter.add(spheres) + + # Global axes + if show_axes: + plotter.add(Axes(mesh_copy, xtitle="X", ytitle="Y", ztitle="Z")) + + # Camera setup + plotter.camera.SetPosition(view["pos"]) + plotter.camera.SetFocalPoint(center) + plotter.camera.SetViewUp(view["up"]) + + plotter.background((255, 255, 255)) + plotter.render() + + if name and output_dir: + # Save screenshot + os.makedirs(output_dir, exist_ok=True) + screenshot_path = os.path.join(output_dir, f"{name}.png") + plotter.screenshot(screenshot_path) + # plotter.close() + + logging.info(f"Captured screenshot: {screenshot_path}") + return screenshot_path + + # Return image + image = plotter.screenshot(filename=None, asarray=True) + # plotter.close() + return image diff --git a/src/processing_models/mesh/head_cleaner.py b/src/processing_models/mesh/head_cleaner.py new file mode 100644 index 0000000..db9bddf --- /dev/null +++ b/src/processing_models/mesh/head_cleaner.py @@ -0,0 +1,326 @@ +import logging +import numpy as np +import pandas as pd + +from vtk import vtkPolyDataConnectivityFilter +from typing import Any, Dict, List, Optional, Union +from vedo import Box, Mesh, Plane, Plotter, Sphere, settings + + +settings.default_backend = "vtk" +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + +class HeadCleaner: + def __init__( + self, + mesh: Mesh, + fiducials: pd.DataFrame, + texture: Optional[np.ndarray] = None, + ) -> None: + """ + Initialize the head cleaning processor. + """ + self.mesh_raw = mesh.clone() + self.fiducials = fiducials.copy() + self.texture = texture.copy() if texture is not None else None + + self.mesh_cropped = None + self.mesh_cleaned = None + self.bounding_box = None + + def _get_fiducials(self, required_list) -> List[np.ndarray]: + """ + Extract fiducials and validate they exist. + """ + for fid in required_list: + if fid not in self.fiducials.index: + raise ValueError(f"Required fiducial {fid} not found") + + return [self.fiducials.loc[fid].to_numpy().astype(float).copy() for fid in required_list] + + def crop_with_bounding_box( + self, + x_margin: float = 0.5, + y_top_margin: float = 0.25, + y_bottom_margin: float = 1.0, + z_margin: float = 0.25, + ) -> Dict[str, Any]: + """ + Create a bounding box based on fiducials with margins and crop the mesh. 
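`HeadCleaner` expects the fiducials as a DataFrame indexed by label, one x/y/z row per landmark, which is the same layout the CSV loader produces elsewhere in this PR. A hedged usage sketch with synthetic coordinates; the import path assumes `src/` is on `sys.path`, and the sphere merely stands in for a real head scan:

```python
import pandas as pd
from vedo import Sphere

from processing_models.mesh.head_cleaner import HeadCleaner  # module path per this PR

# Label is the index, the three columns are x/y/z in millimetres (synthetic values).
fiducials = pd.DataFrame(
    [
        [0.0, 20.0, 95.0],    # NAS
        [-75.0, 0.0, 0.0],    # LPA
        [80.0, 0.0, 0.0],     # RPA
        [0.0, 15.0, -100.0],  # INI
        [0.0, 130.0, 0.0],    # VTX
    ],
    index=["NAS", "LPA", "RPA", "INI", "VTX"],
)

head = Sphere(r=100.0)                            # placeholder mesh
cleaner = HeadCleaner(head, fiducials)
box = cleaner.crop_with_bounding_box()            # fiducial-driven crop with margins
cleaned = cleaner.clean_from_unwanted_objects()   # keep largest connected component
print(box["size"], cleaned.npoints)
```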
+ + x: left-right -> symmetric margin (fraction of size) + y: bottom-up -> top and bottom handled separately (larger bottom margin) + z: back-forward -> symmetric margin + """ + if self.mesh_raw is None: + raise ValueError("Mesh is not loaded") + + # Get required fiducials + required_fids = ["NAS", "LPA", "RPA", "INI", "VTX"] + nas, lpa, rpa, ini, vtx = self._get_fiducials(required_fids) + fiducials = np.stack([nas, lpa, rpa, ini, vtx], axis=0) + + # Compute min/max along each axis + mins, maxs = fiducials.min(axis=0), fiducials.max(axis=0) + x_min, y_min, z_min = mins + x_max, y_max, z_max = maxs + + # Compute ranges + x_range, y_range, z_range = maxs - mins + + # Apply margins + x_min_final = x_min - x_margin * x_range + x_max_final = x_max + x_margin * x_range + y_min_final = y_min - y_bottom_margin * y_range + y_max_final = y_max + y_top_margin * y_range + z_min_final = z_min - z_margin * z_range + z_max_final = z_max + z_margin * z_range + + # Store bounding box info + self.bounding_box = { + "bounds": [ + float(x_min_final), + float(x_max_final), + float(y_min_final), + float(y_max_final), + float(z_min_final), + float(z_max_final), + ], + "center": np.array( + [ + (x_min_final + x_max_final) / 2, + (y_min_final + y_max_final) / 2, + (z_min_final + z_max_final) / 2, + ] + ), + "size": ( + x_max_final - x_min_final, + y_max_final - y_min_final, + z_max_final - z_min_final, + ), + } + + # Get old bounds and position + old_bounds = self.mesh_raw.bounds() + pos = np.array(self.mesh_raw.center_of_mass()) + + # Convert bounds to local coordinates + x_min_local, x_max_local, y_min_local, y_max_local, z_min_local, z_max_local = old_bounds + x_min_local, y_min_local, z_min_local = ( + np.array([x_min_local, y_min_local, z_min_local]) - pos + ) + x_max_local, y_max_local, z_max_local = ( + np.array([x_max_local, y_max_local, z_max_local]) - pos + ) + + # New bounds (from bounding_box) + x_min_final, x_max_final, y_min_final, y_max_final, z_min_final, z_max_final = ( + self.bounding_box["bounds"] + ) + x_min_final, y_min_final, z_min_final = ( + np.array([x_min_final, y_min_final, z_min_final]) - pos + ) + x_max_final, y_max_final, z_max_final = ( + np.array([x_max_final, y_max_final, z_max_final]) - pos + ) + + # Calculate dimensions + dx = x_max_local - x_min_local + dy = y_max_local - y_min_local + dz = z_max_local - z_min_local + + # Calculate clipping proportions + left = (x_min_final - x_min_local) / dx if dx != 0 else 0 + right = (x_max_local - x_max_final) / dx if dx != 0 else 0 + back = (y_min_final - y_min_local) / dy if dy != 0 else 0 + front = (y_max_local - y_max_final) / dy if dy != 0 else 0 + bottom = (z_min_final - z_min_local) / dz if dz != 0 else 0 + top = (z_max_local - z_max_final) / dz if dz != 0 else 0 + + # Clamp to [0, 1] range + left = max(0, min(1, left)) + right = max(0, min(1, right)) + back = max(0, min(1, back)) + front = max(0, min(1, front)) + bottom = max(0, min(1, bottom)) + top = max(0, min(1, top)) + + # Crop the mesh + self.mesh_cropped = self.mesh_raw.clone() + self.mesh_cropped.crop( + left=left, + right=right, + back=back, + front=front, + bottom=bottom, + top=top, + ) + if self.texture is not None: + self.mesh_cropped.texture(self.texture) + + logging.info(f"Bounds before crop: {old_bounds}") + logging.info(f"Bounds after crop: {self.bounding_box['bounds']}") + logging.info( + f"Cropped mesh: {self.mesh_raw.npoints} -> {self.mesh_cropped.npoints} ({(1 - self.mesh_cropped.npoints / self.mesh_raw.npoints) * 100:.1f}% reduction)" + ) + + return 
self.bounding_box + + def clean_from_unwanted_objects(self) -> Mesh: + """ + Remove unwanted objects by keeping only the largest connected component. + """ + if self.mesh_cropped is None: + if self.mesh_raw is None: + raise ValueError("Mesh is not loaded") + logging.warning("No cropped mesh found, falling back to raw mesh") + self.mesh_cropped = self.mesh_raw.clone() + + connectivity_filter = vtkPolyDataConnectivityFilter() + connectivity_filter.SetInputData(self.mesh_cropped.polydata()) + connectivity_filter.SetExtractionModeToLargestRegion() + connectivity_filter.Update() + + self.mesh_cleaned = Mesh(connectivity_filter.GetOutput()) + if self.texture is not None: + self.mesh_cleaned.texture(self.texture) + + logging.info("Cleaned mesh: kept largest connected component") + return self.mesh_cleaned + + def get_cleaned_mesh(self) -> Mesh: + """ + Return the cleaned mesh. + """ + if self.mesh_cleaned is not None: + return self.mesh_cleaned + logging.warning("No cleaned mesh available, returning cropped mesh") + if self.mesh_cropped is not None: + return self.mesh_cropped + logging.error("No cropped mesh available, returning raw mesh") + if self.mesh_raw is not None: + return self.mesh_raw + raise ValueError("No mesh data available for cleaning") + + def get_summary(self) -> Dict[str, Union[int, float, str]]: + """ + Return summary of processing including mesh sizes, bounding box, and cleaning results. + """ + summary = { + "raw_points": self.mesh_raw.npoints if self.mesh_raw else None, + "cropped_points": self.mesh_cropped.npoints if self.mesh_cropped else None, + "cleaned_points": self.mesh_cleaned.npoints if self.mesh_cleaned else None, + "bounding_box": self.bounding_box if self.bounding_box else None, + } + + if self.mesh_raw and self.mesh_cropped: + summary["crop_reduction_pct"] = 100 * ( + 1 - self.mesh_cropped.npoints / self.mesh_raw.npoints + ) + if self.mesh_cropped and self.mesh_cleaned: + summary["clean_reduction_pct"] = 100 * ( + 1 - self.mesh_cleaned.npoints / self.mesh_cropped.npoints + ) + + return summary + + def _create_box_planes( + self, center: np.ndarray, x_size: float, y_size: float, z_size: float + ) -> List[Plane]: + """ + Create the 6 planes representing the bounding box faces. + """ + # Apply small margin for visualization aesthetics + margin = 0.025 + + # Define the 6 faces of the box + faces = [ + ( + "top", + [0, 1, 0], + [0, y_size / 2, 0], + (z_size * (1 - margin), x_size * (1 - margin)), + ), # +Y face + ( + "bottom", + [0, -1, 0], + [0, -y_size / 2, 0], + (z_size * (1 - margin), x_size * (1 - margin)), + ), # -Y face + ( + "right", + [1, 0, 0], + [x_size / 2, 0, 0], + (z_size * (1 - margin), y_size * (1 - margin)), + ), # +X face + ( + "left", + [-1, 0, 0], + [-x_size / 2, 0, 0], + (z_size * (1 - margin), y_size * (1 - margin)), + ), # -X face + ( + "front", + [0, 0, 1], + [0, 0, z_size / 2], + (x_size * (1 - margin), y_size * (1 - margin)), + ), # +Z face + ( + "back", + [0, 0, -1], + [0, 0, -z_size / 2], + (x_size * (1 - margin), y_size * (1 - margin)), + ), # -Z face + ] + + planes: List[Plane] = [] + for _, normal, offset, size in faces: + plane_center = center + np.array(offset) + plane = Plane(pos=plane_center, normal=normal, s=size) + plane.alpha(0.25).color("gray") + planes.append(plane) + + return planes + + def visualize_result( + self, + show_fiducials: bool = True, + show_box_planes: bool = True, + show_box_mesh: bool = False, + ) -> None: + """ + Visualize the processed mesh with fiducials and bounding box. 
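Keeping only the largest connected region is plain VTK. The sketch below reproduces the same `vtkPolyDataConnectivityFilter` call on a mesh made of two disjoint spheres so the effect is easy to see; the scene itself is synthetic:

```python
from vedo import Mesh, Sphere, merge
from vtk import vtkPolyDataConnectivityFilter

# The small sphere plays the role of an unwanted object (e.g. a scanner artefact).
scene = merge(Sphere(pos=(0, 0, 0), r=100.0), Sphere(pos=(400, 0, 0), r=10.0))

connectivity = vtkPolyDataConnectivityFilter()
connectivity.SetInputData(scene.polydata())   # vedo 2023.4.6 API, as in HeadCleaner
connectivity.SetExtractionModeToLargestRegion()
connectivity.Update()

largest = Mesh(connectivity.GetOutput())
print(scene.npoints, "->", largest.npoints)   # only the large sphere survives
```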
+ """ + # Pick best available mesh (cleaned > cropped > raw) + mesh = self.mesh_cleaned or self.mesh_cropped or self.mesh_raw + if mesh is None: + raise ValueError("No mesh available for visualization") + + plotter = Plotter(title="EEG Head Mesh Processing Result") + plotter.add(mesh) + + # Add fiducial points + if show_fiducials and self.fiducials is not None: + fiducial_spheres = [ + Sphere(pos=fiducial, r=2.5, c="red") for fiducial in self.fiducials.values + ] + plotter.add(fiducial_spheres) + + # Add bounding box + if self.bounding_box is None and (show_box_planes or show_box_mesh): + logging.warning("Bounding box not defined, cannot display box visualization") + elif show_box_planes: + planes = self._create_box_planes( + self.bounding_box["center"], *self.bounding_box["size"] + ) + plotter.add(planes) + elif show_box_mesh: + box_mesh = ( + Box(pos=self.bounding_box["center"], size=self.bounding_box["size"]) + .alpha(0.25) + .color("gray") + ) + plotter.add(box_mesh) + + plotter.show() diff --git a/src/processing_models/mesh/head_pose_aligner.py b/src/processing_models/mesh/head_pose_aligner.py new file mode 100644 index 0000000..039a4d4 --- /dev/null +++ b/src/processing_models/mesh/head_pose_aligner.py @@ -0,0 +1,258 @@ +import logging +import numpy as np +import pandas as pd + +from typing import Dict, List, Optional, Tuple +from vedo import Arrow, Axes, Mesh, Plotter, Sphere, settings + + +settings.default_backend = "vtk" +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + +class HeadPoseAligner: + def __init__( + self, + mesh: Mesh, + fiducials: pd.DataFrame, + texture: Optional[np.ndarray] = None, + ) -> None: + """ + Initialize the head pose aligner. + """ + self.mesh = mesh.clone() + self.fiducials = fiducials.copy() + self.texture = texture.copy() if texture is not None else None + + self.origin = None + self.source_origin = None + self.rotation_matrix = None + self.coordinate_vectors = None + + def _get_fiducials(self, required_list) -> List[np.ndarray]: + """ + Extract fiducials and validate they exist. + """ + for fid in required_list: + if fid not in self.fiducials.index: + raise ValueError(f"Required fiducial {fid} not found") + + return [self.fiducials.loc[fid].to_numpy().astype(float).copy() for fid in required_list] + + def _calculate_origin(self) -> np.ndarray: + """ + Calculate head center as centroid of key anatomical landmarks. + """ + required_for_origin = ["NAS", "LPA", "RPA", "INI"] + nas, lpa, rpa, ini = self._get_fiducials(required_for_origin) + return (nas + lpa + rpa + ini) / 4.0 + + @staticmethod + def _rodrigues_rotation_matrix(axis, angle): + """ + Create rotation matrix using Rodrigues' rotation formula. + """ + cos_angle = np.cos(angle) + sin_angle = np.sin(angle) + + # Cross-product matrix for rotation axis + K = np.array([[0, -axis[2], axis[1]], [axis[2], 0, -axis[0]], [-axis[1], axis[0], 0]]) + + # Rodrigues' formula: R = I + sin(θ)K + (1-cos(θ))K² + rotation_matrix = np.eye(3) + sin_angle * K + (1 - cos_angle) * np.dot(K, K) + + return rotation_matrix + + def _calculate_orientation_matrix( + self, + ) -> Tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray]]: + """ + Compute a rotation matrix that orients the head into a standardized anatomical coordinate system facing forward and horizontally aligned. 
+ Standard coordinate system after transformation: + - X-axis: Left to Right (positive X = right) + - Y-axis: Bottom to Top (positive Y = up) + - Z-axis: Back to Front (positive Z = forward/anterior) + """ + # Validate required fiducials + required = ["NAS", "INI", "LPA", "RPA"] + nas, ini, lpa, rpa = self._get_fiducials(required) + + # Calculate origin (centroid of key fiducials) + origin = self._calculate_origin() + logging.info(f"Origin: {origin}") + + # Step 1: Align face to look forward (Z-axis) + # Calculate the INI to NAS vector (current face direction) + # face_direction = nas - ini + # face_direction = face_direction / np.linalg.norm(face_direction) + plane_center = np.mean([lpa, rpa], axis=0) + face_direction = nas - [ini[0], plane_center[1], ini[2]] + face_direction = face_direction / np.linalg.norm(face_direction) + + # Target direction is positive z-axis (forward) + target_z = np.array([0, 0, 1]) + + # Calculate rotation matrix to align face_direction with z-axis + if np.allclose(face_direction, target_z): + # Already aligned + rotation_matrix_1 = np.eye(3) + elif np.allclose(face_direction, -target_z): + # Opposite direction, need 180-degree rotation + if abs(face_direction[2]) > 0.9: + rotation_axis = np.array([0, 1, 0]) + else: + rotation_axis = np.array([1, 0, 0]) + angle = np.pi + rotation_matrix_1 = self._rodrigues_rotation_matrix(rotation_axis, angle) + else: + # General case: calculate rotation axis and angle + rotation_axis = np.cross(face_direction, target_z) + rotation_axis = rotation_axis / np.linalg.norm(rotation_axis) + angle = np.arccos(np.clip(np.dot(face_direction, target_z), -1.0, 1.0)) + rotation_matrix_1 = self._rodrigues_rotation_matrix(rotation_axis, angle) + + # Step 2: Align horizontally (Y-axis) + # Apply first rotation to the horizontal reference points + left_rotated = rotation_matrix_1 @ (lpa - origin) + right_rotated = rotation_matrix_1 @ (rpa - origin) + + # Calculate the Y-difference between left and right points + y_diff = left_rotated[1] - right_rotated[1] + + # Calculate the horizontal distance between points (X-Z plane) + horizontal_distance = np.sqrt( + (left_rotated[0] - right_rotated[0]) ** 2 + (left_rotated[2] - right_rotated[2]) ** 2 + ) + + # Check if points are too close horizontally (degenerate case) + if horizontal_distance < 1e-6: + logging.warning( + "Left and right reference points are vertically aligned. Skipping horizontal alignment." 
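Step 1 builds the rotation from the cross product of the current face direction and the +Z target, plus the clipped arccos of their dot product. A standalone sketch of that general (non-parallel) case, using SciPy's `Rotation` for the axis-angle step; the input vector is made up:

```python
import numpy as np
from scipy.spatial.transform import Rotation

def rotation_onto(v_from: np.ndarray, v_to: np.ndarray) -> np.ndarray:
    """Rotation matrix taking unit vector v_from onto unit vector v_to (general case)."""
    axis = np.cross(v_from, v_to)
    axis = axis / np.linalg.norm(axis)
    angle = np.arccos(np.clip(np.dot(v_from, v_to), -1.0, 1.0))
    return Rotation.from_rotvec(axis * angle).as_matrix()

face_direction = np.array([0.2, -0.1, 0.95])
face_direction = face_direction / np.linalg.norm(face_direction)
R = rotation_onto(face_direction, np.array([0.0, 0.0, 1.0]))
print(np.round(R @ face_direction, 6))  # [0. 0. 1.]
```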
+ ) + rotation_matrix_2 = np.eye(3) + else: + # Calculate the tilt angle around Z-axis + # We want to rotate so that the Y-difference becomes zero + tilt_angle = np.arctan2(y_diff, horizontal_distance) + + # Create rotation matrix around Z-axis to remove the tilt + # Negative angle because we want to counter-rotate the tilt + rotation_matrix_2 = self._rodrigues_rotation_matrix(np.array([0, 0, 1]), -tilt_angle) + + # Combine both rotations: first forward alignment, then horizontal alignment + rotation_matrix = rotation_matrix_2 @ rotation_matrix_1 + + # Calculate the new coordinate system vectors after rotation + x_axis = rotation_matrix @ np.array([1, 0, 0]) # Right direction + y_axis = rotation_matrix @ np.array([0, 1, 0]) # Up direction + z_axis = rotation_matrix @ np.array([0, 0, 1]) # Forward direction + + # Create coordinate vectors dictionary + coordinate_vectors = { + "x": x_axis, + "y": y_axis, + "z": z_axis, + } + + return origin, coordinate_vectors, rotation_matrix + + def visualize_orientation( + self, + mesh_alpha: float = 0.25, + ): + """ + Visualize oriented head mesh with fiducials and coordinate axes. + """ + plotter = Plotter(title="Head Orientation", size=(1200, 800)) + + # Mesh appearance + self.mesh.alpha(mesh_alpha) + if self.texture is None: + self.mesh.color("black").alpha(0.25) + plotter.add(self.mesh) + + # Fiducials + if self.fiducials is not None: + fiducial_spheres = [ + Sphere(pos=fiducial, r=2.5, c="red") for fiducial in self.fiducials.values + ] + plotter.add(fiducial_spheres) + + if mesh_alpha < 1.0 and self.origin is not None and self.coordinate_vectors is not None: + # Mark origin + plotter.add(Sphere(pos=self.origin, r=10, c="black")) + + # Coordinate system arrows + plotter.add( + Arrow( + self.origin, + self.origin + self.coordinate_vectors["x"] * 75.0, + c="red", + s=0.5, + ) + ) + plotter.add( + Arrow( + self.origin, + self.origin + self.coordinate_vectors["y"] * 75.0, + c="green", + s=0.5, + ) + ) + plotter.add( + Arrow( + self.origin, + self.origin + self.coordinate_vectors["z"] * 75.0, + c="blue", + s=0.5, + ) + ) + + # Global axes + axes_actor = Axes(self.mesh, xtitle="X (Right)", ytitle="Y (Up)", ztitle="Z (Forward)") + plotter.add(axes_actor) + + plotter.show() + + # Reset mesh alpha + self.mesh.alpha(1.0) + + def orient_head_mesh( + self, + show_result: bool = False, + ) -> Tuple[Mesh, pd.DataFrame]: + """ + Orient the head mesh and fiducials into standard coordinate space. 
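Step 2 removes the residual head roll by counter-rotating about Z until the ears sit at the same height. A compact variant that measures the roll directly in the XY-plane; the fiducial coordinates are synthetic and the sign conventions are chosen for this example rather than taken from the class:

```python
import numpy as np

lpa = np.array([-74.0, 6.0, 1.0])    # left ear slightly higher: the head is rolled
rpa = np.array([76.0, -6.0, -1.0])

d = rpa - lpa
roll = np.arctan2(d[1], d[0])        # signed roll of the ear line in the XY-plane
c, s = np.cos(-roll), np.sin(-roll)  # counter-rotate about Z to undo it
Rz = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

print(round(float((Rz @ lpa)[1] - (Rz @ rpa)[1]), 6))  # ~0.0: ears level again
```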
+ """ + if self.mesh is None: + raise ValueError("Mesh is not loaded") + if self.fiducials is None or self.fiducials.empty: + raise ValueError("Fiducials are not loaded") + + # Calculate orientation + self.source_origin, self.coordinate_vectors, self.rotation_matrix = ( + self._calculate_orientation_matrix() + ) + + # Apply transformation to mesh (translate then rotate) + mesh_vertices = self.mesh.points().copy() + transformed_vertices = (mesh_vertices - self.source_origin) @ self.rotation_matrix.T + self.mesh.points(transformed_vertices) + + # Apply same transformation to fiducials + fiducial_coords = self.fiducials.to_numpy() + transformed_fiducials = (fiducial_coords - self.source_origin) @ self.rotation_matrix.T + self.fiducials = pd.DataFrame( + transformed_fiducials, + index=self.fiducials.index, + columns=self.fiducials.columns, + ) + + # Update origin + self.origin = self._calculate_origin() + + # Optional visualization + if show_result: + self.visualize_orientation() + + return self.mesh, self.fiducials diff --git a/src/processing_models/mesh/mesh_loader.py b/src/processing_models/mesh/mesh_loader.py new file mode 100644 index 0000000..70e4d0b --- /dev/null +++ b/src/processing_models/mesh/mesh_loader.py @@ -0,0 +1,408 @@ +import os +import logging +import numpy as np +import pandas as pd + +from PIL import Image +from pathlib import Path +from vedo import Axes, Mesh, Plotter, Sphere, settings +from typing import Dict, Literal, Optional, Union, Tuple + +from .head_cleaner import HeadCleaner +from .cap_extractor import CapExtractor +from .head_capturer import HeadCapturer +from .head_pose_aligner import HeadPoseAligner +from .electrode_curvature_detector import ElectrodeCurvatureDetector + + +settings.default_backend = "vtk" +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + +class MeshLoader: + def __init__( + self, + mesh_path: Union[str, Path], + texture_path: Optional[Union[str, Path]] = None, + fiducials_path: Optional[Union[str, Path]] = None, + ) -> None: + """ + Initialize the 3D head scanned mesh loader. + """ + self.mesh = self._load_mesh(mesh_path) + self.mesh_raw = self.mesh.clone() if self.mesh is not None else None + self.mesh_preprocessed = self.mesh.clone() if self.mesh is not None else None + self.mesh_cleaned = self.mesh.clone() if self.mesh is not None else None + self.mesh_extracted = self.mesh.clone() if self.mesh is not None else None + self.texture = self._load_texture(texture_path) if texture_path is not None else None + self.fiducials = self._load_fiducials(fiducials_path) + self.preprocessed = False + self.curvatures = None + + self.aligner = None + self.capturer = None + self.cleaner = None + self.cap_extractor = None + self.curvature_extractor = None + + self.preprocess_data() + + def _load_mesh(self, mesh_path: Union[str, Path]) -> Mesh: + """ + Load a 3D mesh from a file. + """ + if not os.path.exists(mesh_path): + logging.error(f"Mesh file not found: {mesh_path}") + return None + logging.info(f"Loading mesh from {mesh_path}") + mesh = Mesh(mesh_path) + return mesh + + def _load_texture(self, texture_path: Union[str, Path]) -> np.ndarray: + """ + Load a texture/image for mesh mapping from a file. 
+ """ + if not os.path.exists(texture_path): + logging.error(f"Texture file not found: {texture_path}") + return None + logging.info(f"Loading texture from {texture_path}") + image = Image.open(texture_path) + texture = np.array(image) + self._apply_texture(texture) + return texture + + def _load_fiducials(self, fiducials_path: Union[str, Path]) -> pd.DataFrame: + """ + Load fiducial points from a CSV file. + - Fiducial codes are indices. + """ + if fiducials_path is None: + logging.warning("Fiducials file not provided") + return None + if not os.path.exists(fiducials_path): + logging.error(f"Fiducials file not found: {fiducials_path}") + return None + + fiducials = pd.read_csv(fiducials_path, header=None, index_col=0) + return fiducials + + def _apply_texture(self, texture: np.ndarray) -> None: + """ + Apply the loaded texture to the mesh. + """ + if self.mesh is not None and texture is not None: + logging.info("Applying texture to the mesh...") + self.mesh.texture(texture) + + def _compute_scale_factor(self) -> float: + """ + Estimate the scale factor needed to convert the mesh into millimeter units. + """ + if self.fiducials is not None: + fiducial_pairs = [ + ("NAS", "INI"), + ("LPA", "RPA"), + ("NAS", "LPA"), + ("NAS", "RPA"), + ] + + # Compute all available fiducial distances + distances = [ + np.linalg.norm(self.fiducials.loc[p1] - self.fiducials.loc[p2]) + for p1, p2 in fiducial_pairs + if p1 in self.fiducials.index and p2 in self.fiducials.index + ] + distance = max(distances) if distances else None + + # Use mesh bounds + bounds = self.mesh.bounds() + width = bounds[1] - bounds[0] + height = bounds[3] - bounds[2] + depth = bounds[5] - bounds[4] + distance = max(width, height, depth) + + # Detect units based on typical head dimensions + if distance < 1: # meters -> millimeters + return 1000.0 + elif distance < 100: # centimeters -> millimeters + return 10.0 + else: # already millimeters + return 1.0 + + def _convert_to_mm(self) -> None: + """ + Convert the mesh to millimeter scale. + """ + if self.mesh is None: + raise ValueError("Mesh is not loaded") + + scale_factor = self._compute_scale_factor() + logging.info(f"Scale factor for conversion: {scale_factor}") + + if scale_factor != 1.0: + self.mesh.scale(scale_factor) + + def preprocess_data(self) -> None: + """ + Preprocess the loaded mesh data. 
+ """ + if self.preprocessed: + logging.info("Mesh data already preprocessed") + return + if self.mesh is None: + raise ValueError("Mesh is not loaded") + logging.info("Preprocessing data...") + + # Remove duplicate vertices (if any) + if self.mesh.npoints != len(set(map(tuple, self.mesh.points()))): + logging.info("Removing duplicate vertices from the mesh...") + self.mesh.clean() + + # Convert to mm scale first + self._convert_to_mm() + + # Center using center of mass (alternative: bounding box center) + center_of_mass = self.mesh.center_of_mass() + self.mesh.pos(-center_of_mass) + + # Orient into the standard coordinate center + if self.fiducials is not None: + logging.info("Orienting head mesh and fiducials...") + self.aligner = HeadPoseAligner(self.mesh, self.fiducials, self.texture) + self.mesh, self.fiducials = self.aligner.orient_head_mesh() + + # Fill small holes + self.mesh.fill_holes(size=10.0) + + # Smooth + self.mesh.smooth(niter=10, pass_band=0.25, edge_angle=10, feature_angle=30) + + # Decimate + # self.mesh.decimate(0.9) + + # Compute normals (better visualization) + self.mesh.compute_normals() + + self.preprocessed = True + self.mesh_preprocessed = self.mesh.clone() + + def clean_data( + self, + x_margin: Optional[float] = 0.5, + y_top_margin: Optional[float] = 0.25, + y_bottom_margin: Optional[float] = 1.0, + z_margin: Optional[float] = 0.25, + ) -> None: + """ + Clean the loaded mesh data. + """ + if not self.preprocessed: + raise ValueError("Data must be preprocessed before cleaning") + if self.mesh is None: + raise ValueError("Mesh is not loaded") + if self.fiducials is None: + raise ValueError("Fiducials are not loaded") + logging.info("Cleaning data...") + + # Crop to bounding box with margins and remove unwanted objects + self.cleaner = HeadCleaner(self.mesh, self.fiducials, self.texture) + self.cleaner.crop_with_bounding_box( + x_margin=x_margin, + y_top_margin=y_top_margin, + y_bottom_margin=y_bottom_margin, + z_margin=z_margin, + ) + self.cleaner.clean_from_unwanted_objects() + + self.mesh_cleaned = self.cleaner.get_cleaned_mesh() + self.mesh = self.mesh_cleaned.clone() + + def extract_cap_data(self, margin: float = 0.0) -> None: + """ + Extract the cap portion of the mesh (above the fiducial plane). + """ + if self.mesh is None: + raise ValueError("Mesh is not loaded") + if self.fiducials is None: + raise ValueError("Fiducials are not loaded") + + logging.info("Extracting cap data (margin=%.2f)...", margin) + + # Initialize extractor + self.cap_extractor = CapExtractor(self.mesh, self.fiducials, self.texture) + + # Perform cap extraction + self.cap_extractor.extract_cap(margin=margin) + + # Store cap results + self.mesh_extracted = self.cap_extractor.mesh.clone() + self.mesh = self.mesh_extracted.clone() + + # Initialize curvature detector/extractor + self.curvature_extractor = ElectrodeCurvatureDetector(self.mesh) + + # Compute maps + # self.curvatures = self.curvature_extractor.extract_curvatures() + + def capture_data( + self, + output_dir: Union[str, Path], + image_size: Tuple[int, int] = (1024, 1024), + show_fiducials: bool = True, + show_bounding_box: bool = True, + show_coordinate_vectors: bool = True, + show_cap_plane: bool = True, + custom_cmap: Literal["saliency", "probability", "gradient"] = None, + show_axes: bool = True, + ) -> None: + """ + Capture mesh data from multiple views and save screenshots. 
+ """ + if self.mesh is None: + raise ValueError("Mesh is not loaded.") + if self.fiducials is None: + raise ValueError("Fiducials are not loaded.") + mesh = self.mesh.clone() + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + logging.info("Capturing mesh data to %s...", output_dir) + + # Default settings + mesh_alpha = 1.0 + center = np.mean( + [self.fiducials.loc[label].values for label in ["NAS", "INI", "LPA", "RPA"]], + axis=0, + ) + + # If an aligner exists, update center and transparency + if self.aligner is not None: + mesh_alpha = 0.25 + center = self.aligner.origin + + # Custom colormap + cmap = None + if custom_cmap is not None: + cmap_map = { + "saliency": "saliency_map", + "probability": "probability_map", + "gradient": "gradient_directions", + } + + key = cmap_map.get(custom_cmap) + if key is not None: + cmap = self.curvatures[key] + if key == "gradient": + cmap = np.linalg.norm(cmap, axis=1) + + # Initialize capturer + self.capturer = HeadCapturer(mesh, self.fiducials, center, self.texture) + + # Capture multi-view images + self.capturer.capture_all_views( + output_dir=output_dir, + image_size=image_size, + show_cap=show_cap_plane and self.cap_extractor is not None, + cmap=cmap if cmap is not None and custom_cmap is not None else None, + ) + + # Prepare optional overlays + fiducials = self.fiducials if show_fiducials else None + bounding_box = getattr(self.cleaner, "bounding_box", None) if show_bounding_box else None + coordinate_vectors = ( + getattr(self.aligner, "coordinate_vectors", None) if show_coordinate_vectors else None + ) + cap_plane = getattr(self.cap_extractor, "plane", None) if show_cap_plane else None + + # Don't show bounding box when displaying cap plane + if show_cap_plane and self.cap_extractor is not None: + bounding_box = None + + # Curvature candidates + curvatures = None + if custom_cmap is not None and self.curvatures is not None: + curvatures = [candidate["position"] for candidate in self.curvatures["candidates"]] + + # Capture a custom single view with overlays + self.capturer.capture_single_view( + mesh=mesh, + name="custom", + output_dir=output_dir, + center=center, + image_size=image_size, + mesh_alpha=mesh_alpha, + mesh_color=None, + fiducials=fiducials, + bounding_box=bounding_box, + coordinate_vectors=coordinate_vectors, + cap_plane=cap_plane, + cmap=cmap if cmap is not None and custom_cmap is not None else None, + curvatures=curvatures, + show_axes=show_axes, + ) + + def get_plotter( + self, + show_axes: bool = False, + show_fiducials: bool = False, + mesh_alpha: float = 1.0, + ) -> Plotter: + """ + Visualize the loaded mesh data. + """ + if self.mesh is None: + raise ValueError("Mesh is not loaded") + + plotter = Plotter(title="Electrode Localization Kit", size=(1200, 800)) + + # Mesh appearance + self.mesh.alpha(mesh_alpha) + if self.texture is None: + self.mesh.color("black").alpha(0.25) + plotter.add(self.mesh) + + # Fiducials + if show_fiducials and self.fiducials is not None: + fiducial_spheres = [ + Sphere(pos=fiducial, r=2.5, c="red") for fiducial in self.fiducials.values + ] + plotter.add(fiducial_spheres) + + # Global axes + if show_axes: + axes_actor = Axes(self.mesh, xtitle="X", ytitle="Y", ztitle="Z") + plotter.add(axes_actor) + + return plotter + + def get_summary(self) -> Dict[str, Union[int, float]]: + """ + Get a summary of the mesh properties. 
+ """ + summary = { + "mesh_loaded": self.mesh is not None, + "texture_loaded": self.texture is not None, + "fiducials_loaded": self.fiducials is not None, + "preprocessed": self.preprocessed, + } + + if self.mesh is not None: + summary.update( + { + "num_vertices": self.mesh.npoints, + "num_faces": self.mesh.ncells, + "bounds": self.mesh.bounds().tolist(), + "center_of_mass": self.mesh.center_of_mass().tolist(), + } + ) + + if self.fiducials is not None: + summary.update( + { + "fiducials": self.fiducials.index.tolist(), + "num_fiducials": len(self.fiducials), + } + ) + + return summary diff --git a/src/ui/callbacks/_connect.py b/src/ui/callbacks/_connect.py deleted file mode 100644 index a1fc473..0000000 --- a/src/ui/callbacks/_connect.py +++ /dev/null @@ -1,165 +0,0 @@ -from fileio.scan import load_surface, load_texture -from fileio.mri import load_mri -from fileio.locations import load_locations, save_locations_to_file -from processing_handlers.texture_processing import detect_electrodes -from config.electrode_detector import DogParameters, HoughParameters -from config.sizes import ElectrodeSizes - -from callbacks.display import ( - display_surface, - set_surface_alpha, -) - - -def connect_model(self): - self.model.dataChanged.connect(self.refresh_count_indicators) - self.model.rowsInserted.connect(self.refresh_count_indicators) - self.model.rowsRemoved.connect(self.refresh_count_indicators) - - -def connect_fileio_buttons(self): - self.ui.load_surface_button.clicked.connect( - lambda: load_surface( - self.files, - self.views, - self.headmodels, - [ - ("scan", self.ui.headmodel_frame), - ("labeling_main", self.ui.labeling_main_frame), - ], - self.model, - ) - ) - - self.ui.load_texture_button.clicked.connect( - lambda: load_texture( - self.files, - self.views, - self.headmodels, - [ - ("scan", self.ui.headmodel_frame), - ("labeling_main", self.ui.labeling_main_frame), - ], - self.model, - self.electrode_detector, - ) - ) - self.ui.load_mri_button.clicked.connect( - lambda: load_mri( - self.files, - self.views, - self.headmodels, - [("mri", self.ui.mri_frame)], - self.model, - ) - ) - self.ui.load_locations_button.clicked.connect( - lambda: load_locations( - self.files, - self.views, - [("labeling_reference", self.ui.labeling_reference_frame)], - self.model, - ) - ) - - self.ui.export_locations_button.clicked.connect(lambda: save_locations_to_file(self.model)) - - -def connect_texture_buttons(self): - self.ui.display_dog_button.clicked.connect(self.display_dog) - self.ui.display_hough_button.clicked.connect(self.display_hough) - # texture detect electrodes button slot connection - self.ui.compute_electrodes_button.clicked.connect( - lambda: detect_electrodes( - self.headmodels["scan"], # type: ignore - self.electrode_detector, - self.model, - ) - ) - - -def connect_display_surface_buttons(self): - self.ui.display_head_button.clicked.connect(lambda: display_surface(self.views["scan"])) - self.ui.display_mri_button.clicked.connect(lambda: display_surface(self.views["mri"])) - self.ui.label_display_button.clicked.connect( - lambda: display_surface(self.views["labeling_main"]) - ) - - -def connect_alpha_sliders(self): - self.ui.head_alpha_slider.valueChanged.connect( - lambda: set_surface_alpha(self.views["scan"], self.ui.head_alpha_slider.value() / 100) - ) - self.ui.mri_alpha_slider.valueChanged.connect( - lambda: set_surface_alpha(self.views["mri"], self.ui.mri_alpha_slider.value() / 100) - ) - self.ui.mri_head_alpha_slider.valueChanged.connect( - lambda: set_surface_alpha( - 
self.views["mri"], - self.ui.mri_head_alpha_slider.value() / 100, - actor_index=1, - ) - ) - - -def connect_configuration_boxes(self): - # texture DoG spinbox slot connections - self.ui.kernel_size_spinbox.valueChanged.connect(self.display_dog) - self.ui.sigma_spinbox.valueChanged.connect(self.display_dog) - self.ui.diff_factor_spinbox.valueChanged.connect(self.display_dog) - - # texture Hough spinbox slot connections - self.ui.param1_spinbox.valueChanged.connect(self.display_hough) - self.ui.param2_spinbox.valueChanged.connect(self.display_hough) - self.ui.min_dist_spinbox.valueChanged.connect(self.display_hough) - self.ui.min_radius_spinbox.valueChanged.connect(self.display_hough) - self.ui.max_radius_spinbox.valueChanged.connect(self.display_hough) - - # surface configuration slot connections - self.ui.sphere_size_spinbox.valueChanged.connect(self.update_surf_config) - self.ui.flagposts_checkbox.stateChanged.connect(self.update_surf_config) - self.ui.flagpost_height_spinbox.valueChanged.connect(self.update_surf_config) - self.ui.flagpost_size_spinbox.valueChanged.connect(self.update_surf_config) - - # mri configuration slot connections - self.ui.mri_sphere_size_spinbox.valueChanged.connect(self.update_mri_config) - self.ui.mri_flagposts_checkbox.stateChanged.connect(self.update_mri_config) - self.ui.mri_flagpost_height_spinbox.valueChanged.connect(self.update_mri_config) - self.ui.mri_flagpost_size_spinbox.valueChanged.connect(self.update_mri_config) - - # label configuration slot connections - self.ui.label_sphere_size_spinbox.valueChanged.connect(self.update_reference_labeling_config) - self.ui.label_flagposts_checkbox.stateChanged.connect(self.update_reference_labeling_config) - self.ui.label_flagpost_height_spinbox.valueChanged.connect( - self.update_reference_labeling_config - ) - self.ui.label_flagpost_size_spinbox.valueChanged.connect(self.update_reference_labeling_config) - - -def set_defaults_to_configuration_boxes(self): - # texture DoG spinbox default values - self.ui.kernel_size_spinbox.setValue(DogParameters.KSIZE) - self.ui.sigma_spinbox.setValue(DogParameters.SIGMA) - self.ui.diff_factor_spinbox.setValue(DogParameters.FACTOR) - - # texture Hough spinbox default values - self.ui.param1_spinbox.setValue(HoughParameters.PARAM1) - self.ui.param2_spinbox.setValue(HoughParameters.PARAM2) - self.ui.min_dist_spinbox.setValue(HoughParameters.MIN_DISTANCE) - self.ui.min_radius_spinbox.setValue(HoughParameters.MIN_RADIUS) - self.ui.max_radius_spinbox.setValue(HoughParameters.MAX_RADIUS) - - # electrode size spinbox default values - self.ui.sphere_size_spinbox.setValue(ElectrodeSizes.HEADSCAN_ELECTRODE_SIZE) - self.ui.flagpost_size_spinbox.setValue(ElectrodeSizes.HEADSCAN_FLAGPOST_SIZE) - self.ui.flagpost_height_spinbox.setValue(ElectrodeSizes.HEADSCAN_FLAGPOST_HEIGHT) - - # mri electrode size spinbox default values - self.ui.mri_sphere_size_spinbox.setValue(ElectrodeSizes.MRI_ELECTRODE_SIZE) - self.ui.mri_flagpost_size_spinbox.setValue(ElectrodeSizes.MRI_FLAGPOST_SIZE) - self.ui.mri_flagpost_height_spinbox.setValue(ElectrodeSizes.MRI_FLAGPOST_HEIGHT) - - # label electrode size spinbox default values - self.ui.label_sphere_size_spinbox.setValue(ElectrodeSizes.LABEL_ELECTRODE_SIZE) - self.ui.label_flagpost_size_spinbox.setValue(ElectrodeSizes.LABEL_FLAGPOST_SIZE) - self.ui.label_flagpost_height_spinbox.setValue(ElectrodeSizes.LABEL_FLAGPOST_HEIGHT) diff --git a/src/ui/callbacks/connect/connect_detection.py b/src/ui/callbacks/connect/connect_detection.py new file mode 100644 index 
0000000..45cc99d --- /dev/null +++ b/src/ui/callbacks/connect/connect_detection.py @@ -0,0 +1,21 @@ +from processing_handlers.detection_processing import detect_electrodes, process_mesh + + +def connect_detection_buttons(self): + self.ui.process_button.clicked.connect( + lambda: process_mesh( + self.views["scan"], + self.headmodels["scan"], + self.model, + self.loaders, + self.ui, + ) + ) + self.ui.detect_button.clicked.connect( + lambda: detect_electrodes( + self.views["scan"], + self.headmodels["scan"], + self.model, + self.loaders, + ) + ) diff --git a/src/ui/callbacks/connect/connect_fileio.py b/src/ui/callbacks/connect/connect_fileio.py index 17367c8..b8d2452 100644 --- a/src/ui/callbacks/connect/connect_fileio.py +++ b/src/ui/callbacks/connect/connect_fileio.py @@ -14,6 +14,7 @@ def connect_fileio_buttons(self): ("labeling_main", self.ui.labeling_main_frame), ], self.model, + self.ui, ) ) @@ -28,6 +29,7 @@ def connect_fileio_buttons(self): ], self.model, self.electrode_detector, + self.ui, ) ) self.ui.load_mri_button.clicked.connect( diff --git a/src/ui/callbacks/connect/connect_sliders.py b/src/ui/callbacks/connect/connect_sliders.py index 7d2528b..701cb3c 100644 --- a/src/ui/callbacks/connect/connect_sliders.py +++ b/src/ui/callbacks/connect/connect_sliders.py @@ -3,14 +3,10 @@ def connect_alpha_sliders(self): self.ui.head_alpha_slider.valueChanged.connect( - lambda: set_surface_alpha( - self.views["scan"], self.ui.head_alpha_slider.value() / 100 - ) + lambda: set_surface_alpha(self.views["scan"], self.ui.head_alpha_slider.value() / 100) ) self.ui.mri_alpha_slider.valueChanged.connect( - lambda: set_surface_alpha( - self.views["mri"], self.ui.mri_alpha_slider.value() / 100 - ) + lambda: set_surface_alpha(self.views["mri"], self.ui.mri_alpha_slider.value() / 100) ) self.ui.mri_head_alpha_slider.valueChanged.connect( lambda: set_surface_alpha( diff --git a/src/ui/pyloc_main_window.py b/src/ui/pyloc_main_window.py index 21f4783..1dedca9 100755 --- a/src/ui/pyloc_main_window.py +++ b/src/ui/pyloc_main_window.py @@ -1,6 +1,6 @@ # Form implementation generated from reading ui file 'pyloc_main.ui' # -# Created by: PyQt6 UI code generator 6.8.1 +# Created by: PyQt6 UI code generator 6.7.0 # # WARNING: Any manual changes made to this file will be lost when pyuic6 is # run again. Do not edit this file unless you know what you are doing. 
@@ -586,9 +586,9 @@ def setupUi(self, ELK): self.sphere_size_spinbox = QtWidgets.QDoubleSpinBox(parent=self.groupBox_4) self.sphere_size_spinbox.setMinimumSize(QtCore.QSize(0, 30)) self.sphere_size_spinbox.setMaximumSize(QtCore.QSize(16777215, 30)) - self.sphere_size_spinbox.setDecimals(3) - self.sphere_size_spinbox.setSingleStep(0.005) - self.sphere_size_spinbox.setProperty("value", 0.02) + self.sphere_size_spinbox.setDecimals(1) + self.sphere_size_spinbox.setSingleStep(0.5) + self.sphere_size_spinbox.setProperty("value", 3.5) self.sphere_size_spinbox.setObjectName("sphere_size_spinbox") self.verticalLayout_14.addWidget(self.sphere_size_spinbox) self.flagposts_checkbox = QtWidgets.QCheckBox(parent=self.groupBox_4) @@ -615,11 +615,39 @@ def setupUi(self, ELK): self.flagpost_height_spinbox = QtWidgets.QDoubleSpinBox(parent=self.groupBox_4) self.flagpost_height_spinbox.setMinimumSize(QtCore.QSize(0, 30)) self.flagpost_height_spinbox.setMaximumSize(QtCore.QSize(16777215, 30)) - self.flagpost_height_spinbox.setSingleStep(0.01) - self.flagpost_height_spinbox.setProperty("value", 0.05) + self.flagpost_height_spinbox.setSingleStep(0.5) + self.flagpost_height_spinbox.setProperty("value", 5.0) self.flagpost_height_spinbox.setObjectName("flagpost_height_spinbox") self.verticalLayout_14.addWidget(self.flagpost_height_spinbox) self.verticalLayout_2.addWidget(self.groupBox_4) + self.process_button = QtWidgets.QPushButton(parent=self.widget_2) + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Minimum, QtWidgets.QSizePolicy.Policy.Minimum) + sizePolicy.setHorizontalStretch(0) + sizePolicy.setVerticalStretch(0) + sizePolicy.setHeightForWidth(self.process_button.sizePolicy().hasHeightForWidth()) + self.process_button.setSizePolicy(sizePolicy) + self.process_button.setMinimumSize(QtCore.QSize(150, 30)) + self.process_button.setMaximumSize(QtCore.QSize(150, 30)) + font = QtGui.QFont() + font.setPointSize(12) + self.process_button.setFont(font) + self.process_button.setStyleSheet("") + self.process_button.setObjectName("process_button") + self.verticalLayout_2.addWidget(self.process_button) + self.detect_button = QtWidgets.QPushButton(parent=self.widget_2) + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Minimum, QtWidgets.QSizePolicy.Policy.Minimum) + sizePolicy.setHorizontalStretch(0) + sizePolicy.setVerticalStretch(0) + sizePolicy.setHeightForWidth(self.detect_button.sizePolicy().hasHeightForWidth()) + self.detect_button.setSizePolicy(sizePolicy) + self.detect_button.setMinimumSize(QtCore.QSize(150, 30)) + self.detect_button.setMaximumSize(QtCore.QSize(150, 30)) + font = QtGui.QFont() + font.setPointSize(12) + self.detect_button.setFont(font) + self.detect_button.setStyleSheet("") + self.detect_button.setObjectName("detect_button") + self.verticalLayout_2.addWidget(self.detect_button) self.restart_button_2 = QtWidgets.QPushButton(parent=self.widget_2) sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Minimum, QtWidgets.QSizePolicy.Policy.Minimum) sizePolicy.setHorizontalStretch(0) @@ -1083,6 +1111,8 @@ def retranslateUi(self, ELK): self.flagposts_checkbox.setText(_translate("ELK", "Flagposts")) self.label_6.setText(_translate("ELK", "Flagposts Size")) self.label_7.setText(_translate("ELK", "Flagposts Height")) + self.process_button.setText(_translate("ELK", "Process")) + self.detect_button.setText(_translate("ELK", "Detect")) self.restart_button_2.setText(_translate("ELK", "Back")) self.proceed_button_2.setText(_translate("ELK", "Proceed")) 
self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab_2), _translate("ELK", "Headmodel")) diff --git a/src/ui/qt_designer/pyloc_main.ui b/src/ui/qt_designer/pyloc_main.ui index 1bbbc42..df6b288 100644 --- a/src/ui/qt_designer/pyloc_main.ui +++ b/src/ui/qt_designer/pyloc_main.ui @@ -1401,13 +1401,13 @@ color: rgb(255, 255, 255); - 3 + 1 - 0.005000000000000 + 0.500000000000000 - 0.020000000000000 + 3.500000000000000 @@ -1496,16 +1496,82 @@ color: rgb(255, 255, 255); - 0.010000000000000 + 0.500000000000000 - 0.050000000000000 + 5.000000000000000 + + + + + 0 + 0 + + + + + 150 + 30 + + + + + 150 + 30 + + + + + 12 + + + + + + + Process + + + + + + + + 0 + 0 + + + + + 150 + 30 + + + + + 150 + 30 + + + + + 12 + + + + + + + Detect + + + diff --git a/src/ui/state_manager/states.py b/src/ui/state_manager/states.py index 9b7cf04..44f5132 100644 --- a/src/ui/state_manager/states.py +++ b/src/ui/state_manager/states.py @@ -239,7 +239,12 @@ def initialize_processing_states(self): ) self.state_machine[state_name].add_callback(lambda: self.switch_tab(2)) self.state_machine[state_name].add_callback( - lambda: self.update_button_states(restart_button_2=True, proceed_button_2=True) + lambda: self.update_button_states( + restart_button_2=True, + proceed_button_2=True, + process_button=False, + detect_button=False, + ) ) for state_name in [ @@ -252,7 +257,12 @@ def initialize_processing_states(self): ) self.state_machine[state_name].add_callback(lambda: self.switch_tab(2)) self.state_machine[state_name].add_callback( - lambda: self.update_button_states(restart_button_2=False, proceed_button_2=True) + lambda: self.update_button_states( + restart_button_2=False, + proceed_button_2=True, + process_button=False, + detect_button=False, + ) ) # Surface Processing (no locations) @@ -263,7 +273,12 @@ def initialize_processing_states(self): ) self.state_machine[state_name].add_callback(lambda: self.switch_tab(2)) self.state_machine[state_name].add_callback( - lambda: self.update_button_states(restart_button_2=False, proceed_button_2=False) + lambda: self.update_button_states( + restart_button_2=False, + proceed_button_2=False, + process_button=False, + detect_button=False, + ) ) # ------------------------------- diff --git a/src/utils/spatial.py b/src/utils/spatial.py index 7ea5494..8550dd0 100644 --- a/src/utils/spatial.py +++ b/src/utils/spatial.py @@ -1,9 +1,7 @@ import numpy as np -def compute_distance_between_coordinates( - coordinates_1: np.ndarray, coordinates_2: np.ndarray -): +def compute_distance_between_coordinates(coordinates_1: np.ndarray, coordinates_2: np.ndarray): """Returns the distance between two electrodes.""" return np.linalg.norm(coordinates_1 - coordinates_2) @@ -25,7 +23,7 @@ def compute_unit_spherical_coordinates_from_cartesian( def compute_cartesian_coordinates_from_unit_spherical( - spherical_coordinates: list[float] | tuple[float, float] + spherical_coordinates: list[float] | tuple[float, float], ) -> tuple[float, float, float]: """ Computes the cartesian coordinates from the spherical coordinates. @@ -88,9 +86,7 @@ def compute_angular_distance(vector_a: np.ndarray, vector_b: np.ndarray) -> floa Computes the angular distance between two vectors in cartesian coordinates. TODO: Implement the function to compute the angular distance between two vectors in spherical coordinates. 
""" - val = np.dot(vector_a, vector_b) / ( - np.linalg.norm(vector_a) * np.linalg.norm(vector_b) - ) + val = np.dot(vector_a, vector_b) / (np.linalg.norm(vector_a) * np.linalg.norm(vector_b)) if abs(val) > 1: val = 1 if val > 0 else -1 @@ -131,9 +127,7 @@ def align_vectors( return output_vector -def convert_quaternion_to_rotation_matrix( - Q: tuple[float, float, float, float] -) -> np.ndarray: +def convert_quaternion_to_rotation_matrix(Q: tuple[float, float, float, float]) -> np.ndarray: # w, x, y, z - quaternion components w = Q[0] x = Q[1] diff --git a/src/utils/texture.py b/src/utils/texture.py index 7f19b76..746ff33 100644 --- a/src/utils/texture.py +++ b/src/utils/texture.py @@ -3,85 +3,101 @@ from config.colors import HOUGH_CIRCLES_COLOR + # Hough circle detection functions -def compute_hough_circles(color_image, dog_image, - param1: float, param2: float, - min_distance_between_circles: int, - min_radius: int, max_radius: int, - rgb_circles_color: tuple[int, int, int] = HOUGH_CIRCLES_COLOR) -> tuple[np.ndarray, list[np.uint16] | None]: - """Compute circles on the difference of gaussians (DoG) image using Hough circle detection.""" - - circles = cv.HoughCircles(dog_image, - cv.HOUGH_GRADIENT, 1, minDist=min_distance_between_circles, - param1=param1, param2=param2, - minRadius=min_radius, maxRadius=max_radius) +def compute_hough_circles( + color_image, + dog_image, + param1: float, + param2: float, + min_distance_between_circles: int, + min_radius: int, + max_radius: int, + rgb_circles_color: tuple[int, int, int] = HOUGH_CIRCLES_COLOR, +) -> tuple[np.ndarray, list[np.uint16] | None]: + """Compute circles on the difference of gaussians (DoG) image using Hough circle detection.""" + + circles = cv.HoughCircles( + dog_image, + cv.HOUGH_GRADIENT, + 1, + minDist=min_distance_between_circles, + param1=param1, + param2=param2, + minRadius=min_radius, + maxRadius=max_radius, + ) circles_image = color_image.copy() if circles is not None: - circles = np.uint16(np.around(circles)) # type: ignore - for i in circles[0, :]: # type: ignore + circles = np.uint16(np.around(circles)) # type: ignore + for i in circles[0, :]: # type: ignore center = (i[0], i[1]) radius = i[2] cv.circle(circles_image, center, radius, rgb_circles_color, -1) - - return (circles_image, circles) # type: ignore + + return (circles_image, circles) # type: ignore + # Difference of Gaussians (DoG) texture processing functions -def compute_difference_of_gaussians(image: np.ndarray, - ksize: int, sigma: float, F: float, - threshold_level: int) -> np.ndarray: +def compute_difference_of_gaussians( + image: np.ndarray, ksize: int, sigma: float, F: float, threshold_level: int +) -> np.ndarray: """Compute difference of gaussians (DoG) image.""" - + dog_kernel = compute_dog_kernel(ksize, sigma, F) gray = rgb2gray(image) dog = cv.filter2D(src=gray, ddepth=-1, kernel=dog_kernel) - + return gray2binary(dog, threshold_level) + def compute_dog_kernel(ksize: int, sigma: float, F: float) -> np.ndarray: """Compute difference of gaussians (DoG) kernel.""" - + # get gaussian kernel 1 k1_1d = cv.getGaussianKernel(ksize, sigma) k1 = np.dot(k1_1d, k1_1d.T) - + # get gaussian kernel 2 - k2_1d = cv.getGaussianKernel(ksize, sigma*F) + k2_1d = cv.getGaussianKernel(ksize, sigma * F) k2 = np.dot(k2_1d, k2_1d.T) - + # calculate difference of gaussians return k2 - k1 + # Color conversion functions def rgb2gray(rgb_image: np.ndarray) -> np.ndarray: """Convert RGB image to grayscale.""" - + return cv.cvtColor(rgb_image, cv.COLOR_BGR2GRAY) + def 
gray2binary(image: np.ndarray, level: int) -> np.ndarray: """Convert grayscale image to binary image.""" - + image[image < level] = 0 image[image > level] = 255 return image + def get_vertex_from_pixels(self, pixels, mesh, image_size): # Helper function to get the vertex from the mesh that corresponds to # the pixel coordinates # # Written by: Aleksij Kraljic, October 29, 2023 - + # extract the vertices from the mesh vertices = mesh.points() - + # extract the uv coordinates from the mesh - uv = mesh.pointdata['material_0'] - + uv = mesh.pointdata["material_0"] + # convert pixels to uv coordinates - uv_image = [(pixels[0]+0.5)/image_size[0], - 1-(pixels[1]+0.5)/image_size[1]] - + uv_image = [(pixels[0] + 0.5) / image_size[0], 1 - (pixels[1] + 0.5) / image_size[1]] + # find index of closest point in uv with uv_image - uv_idx = np.argmin(np.linalg.norm(uv-uv_image, axis=1)) - - return vertices[uv_idx] \ No newline at end of file + uv_idx = np.argmin(np.linalg.norm(uv - uv_image, axis=1)) + + return vertices[uv_idx] diff --git a/src/view/interactive_surface_view.py b/src/view/interactive_surface_view.py index 4067046..3185604 100644 --- a/src/view/interactive_surface_view.py +++ b/src/view/interactive_surface_view.py @@ -11,8 +11,16 @@ from ui.label_dialog import LabelingDialog +import os +from ui.pyloc_main_window import Ui_ELK + +ENV = os.getenv("ELK_ENV", "production") + class InteractiveSurfaceView(SurfaceView): + + required_fiducials = {"NAS", "INI", "LPA", "RPA", "VTX"} + def __init__( self, frame, @@ -21,8 +29,10 @@ def __init__( config={}, model: CapModel | None = None, parent=None, + ui: Ui_ELK = None, ): super().__init__(frame, mesh, modality, config, model, parent) + self.ui = ui self._interaction_state = "x" @@ -160,7 +170,12 @@ def _on_left_click_release(self, evt): electrode = Electrode(point, modality=self.modality[0], label=label) self.model.insert_electrode(electrode) elif self._interaction_state == "E": - dialog = LabelingDialog() + fiducials = [ + fiducial.label for fiducial in self.model.get_fiducials([self.modality[0]]) + ] + unlabeled = list(self.required_fiducials - set(fiducials)) + unlabeled.sort() + dialog = LabelingDialog(unlabeled) dialog.exec() label = dialog.get_electrode_label() electrode = Electrode(point, modality=self.modality[0], label=label, fiducial=True) @@ -188,6 +203,15 @@ def _on_keypress(self, evt): pos="top-left", ) self._plotter.add(self.text_state) + + # Check for all required fiducials + if self.ui: + fiducials = [ + fiducial.label for fiducial in self.model.get_fiducials([self.modality[0]]) + ] + unlabeled = list(self.required_fiducials - set(fiducials)) + if ENV == "development" or not unlabeled: + self.ui.process_button.setEnabled(True) elif evt.keyPressed == "s": self._interaction_state = "s" self._plotter.background(c1="#b1fcb3", c2="white") diff --git a/version.txt b/version.txt deleted file mode 100644 index 39e898a..0000000 --- a/version.txt +++ /dev/null @@ -1 +0,0 @@ -0.7.1
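For reference, the two-stage texture detection in `utils/texture.py` (a DoG band-pass followed by Hough circle detection) can be smoke-tested on a synthetic image. The import path assumes `src/` is on `sys.path`, and the parameter values are illustrative rather than tuned defaults:

```python
import cv2 as cv
import numpy as np

from utils.texture import compute_difference_of_gaussians, compute_hough_circles

# Synthetic texture: a dark background with a few bright "electrode" discs.
texture = np.zeros((256, 256, 3), dtype=np.uint8)
for cx, cy in [(64, 64), (128, 190), (200, 80)]:
    cv.circle(texture, (cx, cy), 9, (255, 255, 255), -1)

dog = compute_difference_of_gaussians(texture, ksize=15, sigma=2.0, F=1.6, threshold_level=30)
overlay, circles = compute_hough_circles(
    texture,
    dog,
    param1=50,
    param2=10,
    min_distance_between_circles=20,
    min_radius=5,
    max_radius=15,
)
print(0 if circles is None else circles.shape[1])  # number of candidate circles found
```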