diff --git a/runtime/ops/mapper/__init__.py b/runtime/ops/mapper/__init__.py index ed0a0fcb2..f4e5a93e1 100644 --- a/runtime/ops/mapper/__init__.py +++ b/runtime/ops/mapper/__init__.py @@ -63,5 +63,6 @@ def _import_operators(): from . import video_speech_asr from . import video_subtitle_ocr from . import video_text_ocr + from . import wsi_enhance_operator _import_operators() diff --git a/runtime/ops/mapper/wsi_enhance_operator/README.md b/runtime/ops/mapper/wsi_enhance_operator/README.md new file mode 100644 index 000000000..a3cea4c1e --- /dev/null +++ b/runtime/ops/mapper/wsi_enhance_operator/README.md @@ -0,0 +1,136 @@ +# WSIEnhance Operator + +## Overview + +`wsi_enhance_operator` is a custom mapper operator package for DataMate. + +It includes: + +- operator registration entry +- operator metadata and UI settings +- main pipeline implementation +- WSI reading helpers +- slide segmentation helpers +- patch extraction helpers +- stain normalization helpers +- augmentation helpers + +## Directory Structure + +```text +wsi_enhance_operator/ +├── __init__.py +├── metadata.yml +├── process.py +├── README.md +├── requirements.txt +├── augmentations/ +│ ├── __init__.py +│ └── augmentations.py +├── slidesegmenter/ +│ ├── __init__.py +│ ├── _model_utils.py +│ ├── slidesegmenter.py +│ └── model_files/ +│ └── __init__.py +├── stain_normalization/ +│ ├── __init__.py +│ └── stain_normalization.py +├── wsi_processor/ +│ ├── __init__.py +│ └── wsi_processor.py +└── wsi_reader/ + ├── __init__.py + ├── wsi_reader.py + └── wsi_types.py +``` + +## File Responsibilities + +- `__init__.py`: registers `WSIEnhanceMapper` into DataMate operator registry +- `metadata.yml`: defines operator identity, category, runtime resources, and frontend settings +- `process.py`: main mapper entry, parameter parsing, segmentation, patch extraction, and artifact export +- `augmentations/`: patch augmentation utilities +- `slidesegmenter/`: segmentation model loading and inference helpers +- 
`stain_normalization/`: stain normalization logic +- `wsi_processor/`: contour and detection post-processing helpers +- `wsi_reader/`: WSI file reading abstraction +- `requirements.txt`: Python dependencies required by this operator package + +## Model Path + +The runtime environment is expected to provide model files under: + +- `/models/WSIEnhance/` + +Default `model_folder`: + +- `2025-10-18` + +## Input Expectations + +The operator accepts a `sample` dictionary. Common supported input fields are: + +- `filePath`: source WSI path +- `image_path`: source WSI path alias +- `source_path`: optional source path alias +- `export_path`, `exportPath`, or `output_dir`: output root directory + +## Main Settings + +Common configurable settings in `metadata.yml` include: + +- `model_folder` +- `thumbnail_size` +- `patch_size` +- `patch_bg_thresh` +- `patch_max_bg_ratio` +- `save_patches` +- `enable_stain_normalize` +- `save_normalized_patches` +- `stain_method` +- `stain_target` +- `enable_augmentation` +- `save_augmented_patches` +- `aug_factor` +- `aug_rotate` +- `aug_flip` +- `aug_color_jitter` + +## Output Layout + +For each input slide, the operator can generate: + +- `segmentation/thumbnail.png` +- `segmentation/thumbnail_overlay.png` +- `segmentation/coords_thumbnail.json` +- `patch_extract/patch_positions.json` +- `patch_extract/pipeline_manifest.json` +- `patch_extract/patches/` when raw patch saving is enabled +- `patch_extract/patches_normalized/` when stain normalization output saving is enabled +- `patch_extract/patches_augmented/` when augmentation output saving is enabled + +## Output Fields + +The operator writes result paths and summary fields back into `sample`. 
Common output fields include: + +- `thumbnail_path` +- `thumbnail_overlay_path` +- `coords_thumbnail_json` +- `patch_positions_json` +- `pipeline_manifest_json` +- `patch_count` +- `normalized_patch_count` +- `augmented_count` +- `patches_dir` +- `normalized_patches_dir` +- `augmented_patches_dir` + +## Usage Notes + +1. Place the operator directory under `runtime/ops/mapper/wsi_enhance_operator`. +2. Ensure `metadata.yml`, `process.py`, and `__init__.py` are present. +3. Ensure required WSI runtime dependencies are installed, including OpenSlide-related dependencies. +4. Ensure model files are mounted under `/models/WSIEnhance/`. +5. Import the operator package from `runtime/ops/mapper/__init__.py`. +6. Configure parameters from the DataMate frontend or task definition. diff --git a/runtime/ops/mapper/wsi_enhance_operator/__init__.py b/runtime/ops/mapper/wsi_enhance_operator/__init__.py new file mode 100644 index 000000000..4e65e5a30 --- /dev/null +++ b/runtime/ops/mapper/wsi_enhance_operator/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- +""" +WSIEnhance 全幻灯片成像处理算子注册入口 +""" + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module( + module_name='WSIEnhanceMapper', + module_path="ops.mapper.wsi_enhance_operator.process" +) diff --git a/runtime/ops/mapper/wsi_enhance_operator/augmentations/__init__.py b/runtime/ops/mapper/wsi_enhance_operator/augmentations/__init__.py new file mode 100644 index 000000000..17f42dd69 --- /dev/null +++ b/runtime/ops/mapper/wsi_enhance_operator/augmentations/__init__.py @@ -0,0 +1,7 @@ +""" +数据增强模块 +""" + +from .augmentations import Augmenter, AugmentationConfig + +__all__ = ["Augmenter", "AugmentationConfig"] diff --git a/runtime/ops/mapper/wsi_enhance_operator/augmentations/augmentations.py b/runtime/ops/mapper/wsi_enhance_operator/augmentations/augmentations.py new file mode 100644 index 000000000..55d1f25e4 --- /dev/null +++ b/runtime/ops/mapper/wsi_enhance_operator/augmentations/augmentations.py @@ -0,0 
+1,311 @@ +""" +数据增强模块:支持 WSI patch 的各种增强操作。 + +支持的增强类型: +1. 几何变换:随机旋转、翻转、弹性形变 +2. 颜色变换:亮度、对比度、饱和度、色调调整 +3. 噪声添加:高斯噪声、椒盐噪声 +4. 模糊变换:高斯模糊、运动模糊 +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional, Tuple +import numpy as np + +try: + import cv2 +except Exception as e: + cv2 = None + _CV2_IMPORT_ERR = e +else: + _CV2_IMPORT_ERR = None + + +@dataclass +class AugmentationConfig: + """数据增强配置""" + # 几何变换 + enable_rotate: bool = True # 是否启用随机旋转 + rotate_range: Tuple[int, int] = (-30, 30) # 旋转角度范围 + enable_flip: bool = True # 是否启用随机翻转 + flip_horizontal: bool = True # 水平翻转 + flip_vertical: bool = True # 垂直翻转 + + # 颜色变换 + enable_color_jitter: bool = True # 是否启用颜色抖动 + brightness_range: Tuple[float, float] = (0.8, 1.2) # 亮度调整范围 + contrast_range: Tuple[float, float] = (0.8, 1.2) # 对比度调整范围 + saturation_range: Tuple[float, float] = (0.8, 1.2) # 饱和度调整范围 + hue_range: Tuple[float, float] = (-0.1, 0.1) # 色调调整范围(-0.5~0.5) + + # 噪声添加 + enable_noise: bool = False # 是否启用噪声 + gaussian_noise_var: float = 0.01 # 高斯噪声方差 + salt_pepper_ratio: float = 0.01 # 椒盐噪声比例 + + # 模糊变换 + enable_blur: bool = False # 是否启用模糊 + blur_kernel_size: int = 5 # 模糊核大小 + blur_sigma: float = 1.5 # 高斯模糊 sigma + + # 弹性形变 + enable_elastic: bool = False # 是否启用弹性形变 + elastic_alpha: float = 34.0 # 弹性形变强度 + elastic_sigma: float = 4.0 # 弹性形变平滑度 + + # 输出配置 + output_size: Optional[Tuple[int, int]] = None # 输出尺寸,None 表示保持原尺寸 + + +class Augmenter: + """数据增强器""" + + def __init__(self, config: Optional[AugmentationConfig] = None): + if cv2 is None: + raise ImportError( + "未安装 OpenCV(cv2),无法进行数据增强。\n" + "请安装依赖:pip install opencv-python-headless\n" + f"底层错误:{_CV2_IMPORT_ERR}" + ) + self.cfg = config or AugmentationConfig() + + def augment(self, image: np.ndarray, seed: Optional[int] = None) -> np.ndarray: + """ + 对图像应用数据增强 + + :param image: 输入图像 (H, W, C), RGB 格式 + :param seed: 随机种子,用于复现 + :return: 增强后的图像 + """ + if image is None or image.size == 0: + return image + + # 
设置随机种子 + if seed is not None: + np.random.seed(seed) + + result = image.astype(np.float32) + + # 1. 几何变换 + result = self._apply_geometric(result) + + # 2. 颜色变换 + if self.cfg.enable_color_jitter: + result = self._apply_color_jitter(result) + + # 3. 噪声添加 + if self.cfg.enable_noise: + result = self._apply_noise(result) + + # 4. 模糊变换 + if self.cfg.enable_blur: + result = self._apply_blur(result) + + # 5. 弹性形变 + if self.cfg.enable_elastic: + result = self._apply_elastic(result) + + # 裁剪到目标尺寸 + if self.cfg.output_size is not None: + result = self._crop_to_size(result, self.cfg.output_size) + + # 确保输出为 uint8 + result = np.clip(result, 0, 255).astype(np.uint8) + + return result + + def generate_augmented_batch( + self, + image: np.ndarray, + n: int = 1, + seeds: Optional[List[int]] = None + ) -> List[np.ndarray]: + """ + 生成多个增强版本 + + :param image: 原始图像 + :param n: 生成数量 + :param seeds: 可选的随机种子列表 + :return: 增强图像列表 + """ + results = [] + for i in range(n): + seed = seeds[i] if seeds and i < len(seeds) else None + augmented = self.augment(image, seed=seed) + results.append(augmented) + return results + + def _apply_geometric(self, image: np.ndarray) -> np.ndarray: + """应用几何变换""" + h, w = image.shape[:2] + result = image.copy() + + # 旋转 + if self.cfg.enable_rotate: + angle = np.random.uniform(*self.cfg.rotate_range) + M = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1.0) + result = cv2.warpAffine( + result, M, (w, h), + borderMode=cv2.BORDER_REFLECT_101, + flags=cv2.INTER_LINEAR + ) + + # 翻转 + if self.cfg.enable_flip: + flip_code = -1 # 180 度翻转 + if self.cfg.flip_horizontal and not self.cfg.flip_vertical: + flip_code = 1 # 水平翻转 + elif self.cfg.flip_vertical and not self.cfg.flip_horizontal: + flip_code = 0 # 垂直翻转 + elif np.random.random() < 0.5: + flip_code = 1 + else: + flip_code = 0 + + if flip_code >= 0: + result = cv2.flip(result, flip_code) + + return result + + def _apply_color_jitter(self, image: np.ndarray) -> np.ndarray: + """应用颜色抖动""" + result = 
image.astype(np.float32) + + # 亮度调整 + brightness_factor = np.random.uniform(*self.cfg.brightness_range) + result = result * brightness_factor + + # 对比度调整 + contrast_factor = np.random.uniform(*self.cfg.contrast_range) + mean = np.mean(result, axis=(0, 1), keepdims=True) + result = (result - mean) * contrast_factor + mean + + # 饱和度调整 + saturation_factor = np.random.uniform(*self.cfg.saturation_range) + if saturation_factor != 1.0: + hsv = cv2.cvtColor(result.astype(np.uint8), cv2.COLOR_RGB2HSV) + h, s, v = cv2.split(hsv) + s = np.clip(s.astype(np.float32) * saturation_factor, 0, 255) + hsv = cv2.merge([h, s.astype(np.uint8), v]) + result = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB).astype(np.float32) + + # 色调调整 + hue_factor = np.random.uniform(*self.cfg.hue_range) + if hue_factor != 0: + hsv = cv2.cvtColor(result.astype(np.uint8), cv2.COLOR_RGB2HSV) + h, s, v = cv2.split(hsv) + h = ((h.astype(np.float32) / 180 + hue_factor) % 1) * 180 + hsv = cv2.merge([h.astype(np.uint8), s, v]) + result = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB).astype(np.float32) + + return result + + def _apply_noise(self, image: np.ndarray) -> np.ndarray: + """应用噪声""" + result = image.copy() + + # 高斯噪声 + if self.cfg.gaussian_noise_var > 0: + noise = np.random.normal(0, np.sqrt(self.cfg.gaussian_noise_var) * 255, result.shape) + result = result + noise + + # 椒盐噪声 + if self.cfg.salt_pepper_ratio > 0: + h, w = result.shape[:2] + num_pixels = int(h * w * self.cfg.salt_pepper_ratio) + + # 盐噪声(白色) + salt_coords = [ + np.random.randint(0, h, num_pixels // 2), + np.random.randint(0, w, num_pixels // 2) + ] + for i in range(len(salt_coords[0])): + result[salt_coords[0][i], salt_coords[1][i]] = 255 + + # 胡椒噪声(黑色) + pepper_coords = [ + np.random.randint(0, h, num_pixels // 2), + np.random.randint(0, w, num_pixels // 2) + ] + for i in range(len(pepper_coords[0])): + result[pepper_coords[0][i], pepper_coords[1][i]] = 0 + + return result + + def _apply_blur(self, image: np.ndarray) -> np.ndarray: + """应用模糊""" + 
kernel_size = self._odd(self.cfg.blur_kernel_size) + + # 高斯模糊 + result = cv2.GaussianBlur( + image.astype(np.uint8), + (kernel_size, kernel_size), + self.cfg.blur_sigma + ) + + return result.astype(np.float32) + + def _apply_elastic(self, image: np.ndarray) -> np.ndarray: + """应用弹性形变""" + h, w = image.shape[:2] + + # 生成随机位移场 + sigma = self.cfg.elastic_sigma + alpha = self.cfg.elastic_alpha + + # 生成平滑随机场 + dx = cv2.GaussianBlur( + np.random.randn(h, w).astype(np.float32), + (0, 0), + sigma + ) * alpha + + dy = cv2.GaussianBlur( + np.random.randn(h, w).astype(np.float32), + (0, 0), + sigma + ) * alpha + + # 创建映射 + x, y = np.meshgrid(np.arange(w), np.arange(h)) + map_x = (x + dx).astype(np.float32) + map_y = (y + dy).astype(np.float32) + + # 应用弹性形变 + result = cv2.remap( + image.astype(np.uint8), + map_x, map_y, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REFLECT_101 + ) + + return result.astype(np.float32) + + def _crop_to_size(self, image: np.ndarray, size: Tuple[int, int]) -> np.ndarray: + """裁剪到目标尺寸""" + h, w = image.shape[:2] + target_h, target_w = size + + # 计算中心裁剪区域 + start_y = max(0, (h - target_h) // 2) + start_x = max(0, (w - target_w) // 2) + + end_y = min(h, start_y + target_h) + end_x = min(w, start_x + target_w) + + # 如果目标尺寸大于原图,则填充 + if target_h > h or target_w > w: + result = np.zeros((target_h, target_w, image.shape[2]), dtype=image.dtype) + paste_y = max(0, (target_h - h) // 2) + paste_x = max(0, (target_w - w) // 2) + result[paste_y:paste_y+h, paste_x:paste_x+w] = image + return result + + return image[start_y:end_y, start_x:end_x] + + @staticmethod + def _odd(k: int) -> int: + """确保数字为奇数""" + return k if (k % 2 == 1) else (k + 1) diff --git a/runtime/ops/mapper/wsi_enhance_operator/metadata.yml b/runtime/ops/mapper/wsi_enhance_operator/metadata.yml new file mode 100644 index 000000000..f82269668 --- /dev/null +++ b/runtime/ops/mapper/wsi_enhance_operator/metadata.yml @@ -0,0 +1,149 @@ +name: 'WSIEnhance 一体化算子' +description: '在一个算子中完成 
WSI 轮廓识别、patch 提取、可选 stain 归一化和可选数据增强。' +language: 'python' +vendor: 'huawei' +raw_id: 'WSIEnhanceMapper' +version: '2.0.0' +modal: 'image' +inputs: 'image' +outputs: 'image' +types: + - 'cleaning' +release: + - '整合 segmentation、patch extract、stain normalize、patch augment 四个阶段为一个算子' + - '支持通过参数控制是否保存原始 patch、是否执行 stain 归一化、是否执行数据增强' +runtime: + memory: 2147483648 + cpu: 1.0 + gpu: 0.1 + npu: 0.1 +settings: + model_folder: + name: '模型版本目录' + description: '位于 /models/WSIEnhance 下的 SlideSegmenter 模型目录。' + type: 'input' + defaultVal: '2025-10-18' + required: false + thumbnail_size: + name: '缩略图最大边长' + description: '用于组织分割的缩略图最大边长。' + type: 'slider' + defaultVal: 3072 + min: 1024 + max: 8192 + step: 256 + patch_size: + name: 'Patch 尺寸' + description: '提取 patch 的尺寸。' + type: 'slider' + defaultVal: 256 + min: 64 + max: 512 + step: 32 + patch_bg_thresh: + name: '背景灰度阈值' + description: '灰度高于该值视为背景。' + type: 'slider' + defaultVal: 210 + min: 150 + max: 240 + step: 5 + patch_max_bg_ratio: + name: '最大背景占比' + description: '超过该背景比例的 patch 会被过滤。' + type: 'slider' + defaultVal: 0.85 + min: 0.5 + max: 0.95 + step: 0.05 + save_patches: + name: '保存原始 Patch' + description: '是否将提取出的原始 patch 落盘到数据集目录。启用 Stain 归一化时该项会自动失效,不再保留原始 patch。' + type: 'switch' + defaultVal: 'true' + required: false + checkedLabel: '保存' + unCheckedLabel: '不保存' + enable_stain_normalize: + name: '启用 Stain 归一化' + description: '是否执行 patch 染色归一化。' + type: 'switch' + defaultVal: 'false' + required: false + checkedLabel: '启用' + unCheckedLabel: '关闭' + save_normalized_patches: + name: '保存归一化 Patch' + description: '启用 stain 归一化后,是否保存归一化结果。' + type: 'switch' + defaultVal: 'true' + required: false + checkedLabel: '保存' + unCheckedLabel: '不保存' + stain_method: + name: '归一化方法' + description: '选择染色归一化方法。' + type: 'select' + defaultVal: 'macenko' + required: false + options: + - label: 'Macenko' + value: 'macenko' + - label: 'Reinhard' + value: 'reinhard' + - label: 'Vahadane' + value: 'vahadane' + stain_target: + name: '目标模板路径' 
+ description: '可选目标模板图像路径。' + type: 'input' + defaultVal: '' + required: false + enable_augmentation: + name: '启用数据增强' + description: '是否执行 patch 数据增强。' + type: 'switch' + defaultVal: 'false' + required: false + checkedLabel: '启用' + unCheckedLabel: '关闭' + save_augmented_patches: + name: '保存增强 Patch' + description: '启用数据增强后,是否保存增强结果。' + type: 'switch' + defaultVal: 'true' + required: false + checkedLabel: '保存' + unCheckedLabel: '不保存' + aug_factor: + name: '增强倍数' + description: '每个输入 patch 生成的增强版本数量。' + type: 'slider' + defaultVal: 1 + min: 1 + max: 10 + step: 1 + aug_rotate: + name: '启用旋转' + description: '是否启用随机旋转。' + type: 'switch' + defaultVal: 'true' + required: false + checkedLabel: '启用' + unCheckedLabel: '关闭' + aug_flip: + name: '启用翻转' + description: '是否启用随机翻转。' + type: 'switch' + defaultVal: 'true' + required: false + checkedLabel: '启用' + unCheckedLabel: '关闭' + aug_color_jitter: + name: '启用颜色抖动' + description: '是否启用亮度、对比度、饱和度扰动。' + type: 'switch' + defaultVal: 'true' + required: false + checkedLabel: '启用' + unCheckedLabel: '关闭' diff --git a/runtime/ops/mapper/wsi_enhance_operator/process.py b/runtime/ops/mapper/wsi_enhance_operator/process.py new file mode 100644 index 000000000..72709335a --- /dev/null +++ b/runtime/ops/mapper/wsi_enhance_operator/process.py @@ -0,0 +1,449 @@ +from __future__ import annotations + +import json +import os +import sys +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import cv2 +import numpy as np + +from datamate.core.base_op import Mapper + +MODELS_ROOT = "/models/WSIEnhance" + + +def _ensure_path() -> None: + script_dir = os.path.dirname(os.path.abspath(__file__)) + if script_dir not in sys.path: + sys.path.insert(0, script_dir) + + +def _as_bool(value: Any, default: bool) -> bool: + if value is None: + return default + if isinstance(value, bool): + return value + return str(value).strip().lower() == "true" + + +def _resolve_image_path(sample: Dict[str, Any]) -> str: + for key in ("filePath", 
"image_path", "source_path", "text"): + value = sample.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + return "" + + +def _resolve_export_root(sample: Dict[str, Any], source_path: str) -> str: + for key in ("export_path", "exportPath", "output_dir"): + value = sample.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + if source_path: + return os.path.dirname(source_path) + return "" + + +def _resolve_slide_name(sample: Dict[str, Any], source_path: str) -> str: + file_name = sample.get("fileName") + if isinstance(file_name, str) and file_name.strip(): + stem, _ = os.path.splitext(file_name.strip()) + if stem: + return stem + if source_path: + return os.path.splitext(os.path.basename(source_path))[0] + return "wsi_sample" + + +def _resolve_stage_dir(sample: Dict[str, Any], source_path: str, stage_name: str) -> str: + export_root = _resolve_export_root(sample, source_path) + slide_name = _resolve_slide_name(sample, source_path) + return os.path.abspath(os.path.join(export_root, slide_name, stage_name)) + + +def _save_png(path: str, rgb: np.ndarray) -> None: + Path(path).parent.mkdir(parents=True, exist_ok=True) + bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) + cv2.imwrite(path, bgr) + + +def _contours_to_coords(contours: List[np.ndarray]) -> List[List[Tuple[int, int]]]: + out: List[List[Tuple[int, int]]] = [] + for contour in contours: + pts = contour.squeeze(axis=1) if getattr(contour, "ndim", 0) == 3 else contour + out.append([(int(x), int(y)) for x, y in pts]) + return out + + +def _segmentation_output_to_mask(output: Any, shape: Tuple[int, int]) -> np.ndarray: + if output is None: + return np.zeros(shape, dtype=np.uint8) + mask = np.asarray(output) + if mask.ndim == 3: + mask = mask[..., 0] + return ((mask > 0).astype(np.uint8) * 255) + + +def _mask_to_patch_coords( + tissue_mask: np.ndarray, + wsi_w: int, + wsi_h: int, + patch_size: int, +) -> List[Tuple[int, int]]: + mh, mw = tissue_mask.shape[:2] + 
scale_x = wsi_w / mw + scale_y = wsi_h / mh + coords: List[Tuple[int, int]] = [] + ys, xs = np.where(tissue_mask > 0) + if len(xs) == 0: + return coords + x_min, x_max = xs.min(), xs.max() + y_min, y_max = ys.min(), ys.max() + for y in range(int(y_min * scale_y), int((y_max + 1) * scale_y), patch_size): + for x in range(int(x_min * scale_x), int((x_max + 1) * scale_x), patch_size): + cx = int((x + patch_size / 2) / scale_x) + cy = int((y + patch_size / 2) / scale_y) + if 0 <= cx < mw and 0 <= cy < mh and tissue_mask[cy, cx] > 0: + coords.append((x, y)) + return coords + + +def _keep_patch(patch_rgb: np.ndarray, patch_bg_thresh: int, patch_max_bg_ratio: float) -> bool: + if patch_rgb is None or patch_rgb.size == 0: + return False + gray = cv2.cvtColor(patch_rgb, cv2.COLOR_RGB2GRAY) + bg_mask = gray > patch_bg_thresh + return float(bg_mask.mean()) <= patch_max_bg_ratio + + +def _resolve_device() -> str: + try: + import torch + + if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)) and torch.npu.is_available(): + return "npu:0" + if torch.cuda.is_available(): + return "cuda:0" + except Exception: + pass + return "cpu" + + +class WSIEnhanceMapper(Mapper): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model_folder = str(kwargs.get("model_folder", "2025-10-18")).strip() or "2025-10-18" + self.thumbnail_size = int(kwargs.get("thumbnail_size", 3072)) + self.patch_size = int(kwargs.get("patch_size", 256)) + self.patch_bg_thresh = int(kwargs.get("patch_bg_thresh", 210)) + self.patch_max_bg_ratio = float(kwargs.get("patch_max_bg_ratio", 0.85)) + + self.save_patches = _as_bool(kwargs.get("save_patches"), True) + + self.enable_stain_normalize = _as_bool(kwargs.get("enable_stain_normalize"), False) + self.save_normalized_patches = _as_bool(kwargs.get("save_normalized_patches"), True) + self.stain_method = str(kwargs.get("stain_method", "macenko")).strip().lower() or "macenko" + self.stain_target = 
str(kwargs.get("stain_target", "")).strip() + + self.enable_augmentation = _as_bool(kwargs.get("enable_augmentation"), False) + self.save_augmented_patches = _as_bool(kwargs.get("save_augmented_patches"), True) + self.aug_factor = int(kwargs.get("aug_factor", 1)) + self.aug_rotate = _as_bool(kwargs.get("aug_rotate"), True) + self.aug_flip = _as_bool(kwargs.get("aug_flip"), True) + self.aug_color_jitter = _as_bool(kwargs.get("aug_color_jitter"), True) + + self._processor = None + self._segmenter = None + self._augmenter = None + self._normalizer = None + + def _init_components(self) -> None: + if self._processor is not None and self._segmenter is not None: + return + + _ensure_path() + from slidesegmenter.slidesegmenter import SlideSegmenter + from wsi_processor.wsi_processor import ProcessorConfig, WSIProcessor + + model_dir = os.path.join(MODELS_ROOT, self.model_folder) + if not os.path.exists(model_dir): + raise FileNotFoundError(f"SlideSegmenter model directory not found: {model_dir}") + + self._processor = WSIProcessor(ProcessorConfig()) + self._segmenter = SlideSegmenter( + channels_last=True, + tissue_segmentation=True, + pen_marking_segmentation=True, + separate_cross_sections=False, + device=_resolve_device(), + model_folder=self.model_folder, + alternative_directory=MODELS_ROOT, + ) + + if self.enable_stain_normalize and self._normalizer is None: + from stain_normalization.stain_normalization import ( + StainMethod, + StainNormalizationConfig, + StainNormalizer, + ) + + method = { + "macenko": StainMethod.MACENKO, + "reinhard": StainMethod.REINHARD, + "vahadane": StainMethod.VAHADANE, + }.get(self.stain_method, StainMethod.MACENKO) + config = StainNormalizationConfig(method=method) + self._normalizer = StainNormalizer(config) + if self.stain_target and os.path.exists(self.stain_target): + target = cv2.imread(self.stain_target) + if target is not None: + target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB) + self._normalizer.set_target_image(target) + + if 
self.enable_augmentation and self._augmenter is None: + from augmentations.augmentations import AugmentationConfig, Augmenter + + config = AugmentationConfig( + enable_rotate=self.aug_rotate, + enable_flip=self.aug_flip, + enable_color_jitter=self.aug_color_jitter, + ) + self._augmenter = Augmenter(config) + + def _effective_save_patches(self) -> bool: + # Once stain normalization is enabled, raw patches become transient inputs + # and should not be persisted as final dataset artifacts. + return self.save_patches and not self.enable_stain_normalize + + def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]: + try: + self._init_components() + from wsi_reader.wsi_reader import WSIReader + + image_path = _resolve_image_path(sample) + if not image_path or not os.path.exists(image_path): + sample["wsi_enhance_error"] = f"Input WSI not found: {image_path}" + return sample + + segmentation_dir = _resolve_stage_dir(sample, image_path, "segmentation") + extract_dir = _resolve_stage_dir(sample, image_path, "patch_extract") + os.makedirs(segmentation_dir, exist_ok=True) + os.makedirs(extract_dir, exist_ok=True) + + patch_dir = os.path.join(extract_dir, "patches") + normalized_dir = os.path.join(extract_dir, "patches_normalized") + augmented_dir = os.path.join(extract_dir, "patches_augmented") + save_raw_patches = self._effective_save_patches() + + if save_raw_patches: + os.makedirs(patch_dir, exist_ok=True) + if self.enable_stain_normalize: + os.makedirs(normalized_dir, exist_ok=True) + if self.enable_augmentation: + os.makedirs(augmented_dir, exist_ok=True) + + kept_positions: List[Dict[str, int]] = [] + saved_patch_files: List[str] = [] + normalized_files: List[str] = [] + augmented_files: List[str] = [] + patch_count = 0 + normalized_count = 0 + augmented_count = 0 + + with WSIReader(image_path) as reader: + wsi_w, wsi_h = reader.width, reader.height + thumbnail = reader.get_thumbnail((self.thumbnail_size, self.thumbnail_size)) + prediction = 
self._segmenter.segment(thumbnail.astype(np.float32) / 255.0) + tissue_mask = _segmentation_output_to_mask(prediction.get("tissue"), thumbnail.shape[:2]) + pen_mask = _segmentation_output_to_mask(prediction.get("pen"), thumbnail.shape[:2]) + detection = self._processor.build_detection_from_external_masks( + thumbnail_rgb=thumbnail, + tissue_mask=tissue_mask, + note_mask=pen_mask, + global_stain_mask=pen_mask, + ) + + overlay = thumbnail.copy() + cv2.drawContours(overlay, detection.contours["tissue"], -1, (0, 255, 0), 2) + cv2.drawContours(overlay, detection.contours["note"], -1, (255, 0, 0), 2) + if detection.contours.get("artifact"): + cv2.drawContours(overlay, detection.contours["artifact"], -1, (0, 165, 255), 2) + if detection.contours.get("bubble"): + cv2.drawContours(overlay, detection.contours["bubble"], -1, (0, 0, 255), 2) + + thumb_path = os.path.join(segmentation_dir, "thumbnail.png") + overlay_path = os.path.join(segmentation_dir, "thumbnail_overlay.png") + coords_path = os.path.join(segmentation_dir, "coords_thumbnail.json") + + _save_png(thumb_path, thumbnail) + _save_png(overlay_path, overlay) + + coords = { + "source_path": image_path, + "tissue_contours": _contours_to_coords(detection.contours["tissue"]), + "note_contours": _contours_to_coords(detection.contours["note"]), + "artifact_contours": _contours_to_coords(detection.contours.get("artifact", [])), + "bubble_contours": _contours_to_coords(detection.contours.get("bubble", [])), + } + with open(coords_path, "w", encoding="utf-8") as fh: + json.dump(coords, fh, ensure_ascii=False, indent=2) + + tissue_for_patches = cv2.bitwise_and(detection.tissue_mask, cv2.bitwise_not(detection.note_mask)) + tissue_for_patches = cv2.bitwise_and(tissue_for_patches, cv2.bitwise_not(detection.artifact_mask)) + patch_coords = _mask_to_patch_coords(tissue_for_patches, wsi_w, wsi_h, self.patch_size) + + for x, y in patch_coords: + patch = reader.read_region(x, y, self.patch_size, self.patch_size, level=0) + if not 
_keep_patch(patch, self.patch_bg_thresh, self.patch_max_bg_ratio): + continue + + patch_count += 1 + kept_positions.append({"x": int(x), "y": int(y)}) + base_name = f"patch_{x}_{y}" + patch_name = f"{base_name}.png" + + if save_raw_patches: + patch_path = os.path.join(patch_dir, patch_name) + _save_png(patch_path, patch) + saved_patch_files.append(patch_path) + + stain_source = patch + if self.enable_stain_normalize and self._normalizer is not None: + normalized = self._normalizer.normalize(patch) + stain_source = normalized + normalized_count += 1 + if self.save_normalized_patches: + normalized_path = os.path.join(normalized_dir, patch_name) + _save_png(normalized_path, normalized) + normalized_files.append(normalized_path) + + if self.enable_augmentation and self._augmenter is not None: + outputs = self._augmenter.generate_augmented_batch(stain_source, n=self.aug_factor) + for idx, aug in enumerate(outputs, start=1): + augmented_count += 1 + if self.save_augmented_patches: + augmented_path = os.path.join(augmented_dir, f"{base_name}_aug{idx}.png") + _save_png(augmented_path, aug) + augmented_files.append(augmented_path) + + patch_positions_path = os.path.join(extract_dir, "patch_positions.json") + with open(patch_positions_path, "w", encoding="utf-8") as fh: + json.dump( + { + "source_path": image_path, + "wsi_size": {"w": int(wsi_w), "h": int(wsi_h)}, + "patch_size": self.patch_size, + "patch_count": patch_count, + "patches": kept_positions, + }, + fh, + ensure_ascii=False, + indent=2, + ) + + stain_manifest_path = "" + if self.enable_stain_normalize: + stain_manifest_path = os.path.join(normalized_dir, "stain_normalize_manifest.json") + with open(stain_manifest_path, "w", encoding="utf-8") as fh: + json.dump( + { + "source_path": image_path, + "input_count": patch_count, + "normalized_count": normalized_count, + "saved_count": len(normalized_files), + "save_normalized_patches": self.save_normalized_patches, + "normalized_files": normalized_files, + }, + fh, + 
ensure_ascii=False, + indent=2, + ) + + augmentation_manifest_path = "" + if self.enable_augmentation: + augmentation_manifest_path = os.path.join(augmented_dir, "augmentation_manifest.json") + with open(augmentation_manifest_path, "w", encoding="utf-8") as fh: + json.dump( + { + "source_mode": "normalized" if self.enable_stain_normalize else "raw", + "input_count": normalized_count if self.enable_stain_normalize else patch_count, + "augmented_count": augmented_count, + "saved_count": len(augmented_files), + "save_augmented_patches": self.save_augmented_patches, + "augmented_files": augmented_files, + }, + fh, + ensure_ascii=False, + indent=2, + ) + + pipeline_manifest_path = os.path.join(extract_dir, "pipeline_manifest.json") + with open(pipeline_manifest_path, "w", encoding="utf-8") as fh: + json.dump( + { + "source_path": image_path, + "model_root": MODELS_ROOT, + "model_folder": self.model_folder, + "thumbnail_size": self.thumbnail_size, + "patch_size": self.patch_size, + "patch_bg_thresh": self.patch_bg_thresh, + "patch_max_bg_ratio": self.patch_max_bg_ratio, + "save_patches": self.save_patches, + "effective_save_patches": save_raw_patches, + "enable_stain_normalize": self.enable_stain_normalize, + "save_normalized_patches": self.save_normalized_patches, + "stain_method": self.stain_method, + "stain_target": self.stain_target, + "enable_augmentation": self.enable_augmentation, + "save_augmented_patches": self.save_augmented_patches, + "aug_factor": self.aug_factor, + "aug_rotate": self.aug_rotate, + "aug_flip": self.aug_flip, + "aug_color_jitter": self.aug_color_jitter, + "patch_count": patch_count, + "normalized_count": normalized_count, + "augmented_count": augmented_count, + "segmentation_dir": segmentation_dir, + "patch_extract_dir": extract_dir, + }, + fh, + ensure_ascii=False, + indent=2, + ) + + sample["thumbnail_path"] = thumb_path + sample["thumbnail_overlay_path"] = overlay_path + sample["coords_thumbnail_json"] = coords_path + 
sample["patch_positions_json"] = patch_positions_path + sample["pipeline_manifest_json"] = pipeline_manifest_path + sample["patch_count"] = patch_count + sample["normalized_patch_count"] = normalized_count + sample["augmented_count"] = augmented_count + sample["patches_dir"] = patch_dir if save_raw_patches else "" + sample["normalized_patches_dir"] = normalized_dir if self.enable_stain_normalize and self.save_normalized_patches else "" + sample["augmented_patches_dir"] = augmented_dir if self.enable_augmentation and self.save_augmented_patches else "" + if stain_manifest_path: + sample["stain_normalize_manifest_json"] = stain_manifest_path + if augmentation_manifest_path: + sample["augmentation_manifest_json"] = augmentation_manifest_path + sample["wsi_enhance_metadata"] = { + "model_root": MODELS_ROOT, + "model_folder": self.model_folder, + "thumbnail_size": self.thumbnail_size, + "patch_size": self.patch_size, + "patch_bg_thresh": self.patch_bg_thresh, + "patch_max_bg_ratio": self.patch_max_bg_ratio, + "save_patches": self.save_patches, + "effective_save_patches": save_raw_patches, + "enable_stain_normalize": self.enable_stain_normalize, + "save_normalized_patches": self.save_normalized_patches, + "enable_augmentation": self.enable_augmentation, + "save_augmented_patches": self.save_augmented_patches, + "output_dir": os.path.dirname(segmentation_dir), + } + return sample + except Exception as exc: + sample["wsi_enhance_error"] = str(exc) + return sample diff --git a/runtime/ops/mapper/wsi_enhance_operator/requirements.txt b/runtime/ops/mapper/wsi_enhance_operator/requirements.txt new file mode 100644 index 000000000..a145baed0 --- /dev/null +++ b/runtime/ops/mapper/wsi_enhance_operator/requirements.txt @@ -0,0 +1,8 @@ +torch>=2.0.0 +torch_npu>=2.1.0 +numpy>=1.21.0 +scipy>=1.10.0 +opencv-python-headless>=4.5.0 +Pillow>=8.0.0 +openslide-python>=1.1.0 +huggingface-hub>=0.20.0 diff --git a/runtime/ops/mapper/wsi_enhance_operator/slidesegmenter/__init__.py 
b/runtime/ops/mapper/wsi_enhance_operator/slidesegmenter/__init__.py new file mode 100644 index 000000000..ce9ac1c9c --- /dev/null +++ b/runtime/ops/mapper/wsi_enhance_operator/slidesegmenter/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2023 Ruben T Lucassen, UMC Utrecht, The Netherlands +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .slidesegmenter import SlideSegmenter + +__all__ = ['SlideSegmenter'] \ No newline at end of file diff --git a/runtime/ops/mapper/wsi_enhance_operator/slidesegmenter/_model_utils.py b/runtime/ops/mapper/wsi_enhance_operator/slidesegmenter/_model_utils.py new file mode 100644 index 000000000..bd6a093b7 --- /dev/null +++ b/runtime/ops/mapper/wsi_enhance_operator/slidesegmenter/_model_utils.py @@ -0,0 +1,661 @@ +# Copyright 2023 Ruben T Lucassen, UMC Utrecht, The Netherlands +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Implementation of network architecture (modified U-Net) in Pytorch. 
+""" + +import torch +import torch.nn as nn + + +class Block(nn.Module): + + def __init__( + self, + input_channels: int, + output_channels: int, + activation: str, + normalization: str, + residual_connection: bool, + ) -> None: + """ + Implementation of a block of layers. + + Args: + input_channels: Number of channels of the input tensor. + output_channels: Number of channels of the output tensor. + activation: Activation function to non-linearly transform feature maps. + normalization: Type of normalization layer for feature maps. + residual_connection: Indicates whether a residual connection is added. + """ + super().__init__() + + # define activation function + if activation == 'relu': + self.activation = nn.ReLU + elif activation == 'leaky_relu': + self.activation = nn.LeakyReLU + else: + raise ValueError('Invalid argument for activation function.') + + # define normalization layer + if normalization is None: + self.normalization = None + elif 'batch' in normalization: + self.normalization = nn.BatchNorm2d + elif 'instance' in normalization: + self.normalization = nn.InstanceNorm2d + else: + raise ValueError('Invalid argument for normalization layer.') + + # define residual connections attribute + self.residual_connection = residual_connection + + # define layers + self.conv1 = nn.Conv2d( + input_channels, + output_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode='zeros', + bias = False if self.normalization is not None else True, + ) + if self.normalization is not None: + self.norm1 = self.normalization(output_channels) + self.act1 = self.activation(inplace=True) + + self.conv2 = nn.Conv2d( + output_channels, + output_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode='zeros', + bias = False if self.normalization is not None else True, + ) + if self.normalization is not None: + self.norm2 = self.normalization(output_channels) + self.act2 = self.activation(inplace=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + 
""" + Args: + x: Input tensor + Returns: + x: Tensor after operations + """ + x1 = self.conv1(x) + x1 = self.norm1(x1) if self.normalization is not None else x1 + x1 = self.act1(x1) + + x2 = self.conv2(x1) + x2 = self.norm2(x2) if self.normalization is not None else x2 + x2 = self.act2(x2) + x = x1+x2 if self.residual_connection else x2 + + return x + + +class Down(Block): + + def __init__(self, + input_channels: int, + output_channels: int, + activation: str, + normalization: str, + downsample_factor: int, + downsample_method: str, + residual_connection: bool, + ) -> None: + """ + Implementation of a block of layers starting with a downsampling operation. + + Args: + input_channels: Number of channels of the input tensor. + output_channels: Number of channels of the output tensor. + activation: Activation function to non-linearly transform feature maps. + normalization: Type of normalization layer for feature maps. + downsample_factor: Downsampling factor used to reduce the spatial size. + downsample_method: Operation used for downsampling the feature maps. + residual_connection: Indicates whether a residual connection is added. + """ + super().__init__(input_channels, output_channels, activation, + normalization, residual_connection) + + # define downsampling layer as strided convolution + if downsample_method == 'max_pool': + self.downsample = nn.MaxPool2d( + kernel_size=downsample_factor, + stride=downsample_factor, + ) + elif downsample_method == 'strided_conv': + self.downsample = nn.Conv2d( + input_channels, + input_channels, + kernel_size=downsample_factor, + stride=downsample_factor, + padding=0, + ) + elif downsample_method == 'interpolate': + self.downsample = lambda x: nn.functional.interpolate( + x, + scale_factor=1/downsample_factor, + mode='nearest', + ) + else: + raise ValueError('Invalid argument for downsample method.') + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: Input tensor. + Returns: + x: Tensor after operations. 
+ """ + x = self.downsample(x) + x1 = self.conv1(x) + x1 = self.norm1(x1) if self.normalization is not None else x1 + x1 = self.act1(x1) + + x2 = self.conv2(x1) + x2 = self.norm2(x2) if self.normalization is not None else x2 + x2 = self.act2(x2) + x = x1+x2 if self.residual_connection else x2 + + return x + + +class Up(Block): + + def __init__(self, + input_channels: int, + shortcut_channels: int, + output_channels: int, + activation: str, + normalization: str, + upsample_factor: int, + upsample_method: str, + residual_connection: bool, + ) -> None: + """ + Implementation of a block of layers starting with a upsampling operation. + + Args: + input_channels: Number of channels of the input tensor. + shortcut_channels: Number of channels of the shortcut tensor. + output_channels: Number of channels of the output tensor. + activation: Activation function to non-linearly transform feature maps. + normalization: Type of normalization layer for feature maps. + upsample_factor: Upsampling factor used for increasing the spatial size. + upsample_method: Operation used for upsampling the feature maps. + residual_connection: Indicates whether a residual connection is added. + """ + super().__init__(input_channels+shortcut_channels, output_channels, + activation, normalization, residual_connection) + + # define additional layers + if upsample_method == 'transposed_conv': + self.upsample = nn.ConvTranspose2d( + input_channels, + input_channels, + kernel_size=upsample_factor, + stride=upsample_factor, + ) + elif upsample_method == 'interpolate': + self.upsample = nn.Upsample( + scale_factor=upsample_factor, + mode='nearest', + ) + else: + raise ValueError('Invalid argument for upsample method.') + + def forward(self, x_down: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x_down: Input tensor from downsampling path though shortcut. + x: Input tensor from upsampling path. + Returns: + x: Tensor after operations. 
+ """ + x = self.upsample(x) + x = torch.cat([x_down, x], dim=1) + x1 = self.conv1(x) + x1 = self.norm1(x1) if self.normalization is not None else x1 + x1 = self.act1(x1) + + x2 = self.conv2(x1) + x2 = self.norm2(x2) if self.normalization is not None else x2 + x2 = self.act2(x2) + x = x1+x2 if self.residual_connection else x2 + + return x + + +class ModifiedUNet(nn.Module): + + def __init__( + self, + input_channels: int, + filters: int = 8, + activation: str = 'relu', + normalization: str = 'instance', + downsample_method: str = 'max_pool', + downsample_factors: list = [2, 2, 2, 2, 2], + upsample_method: str = 'interpolate', + residual_connection: bool = False, + weight_init: str = 'kaiming_normal', + attach_tissue_decoder: bool = True, + attach_pen_decoder: bool = True, + attach_distance_decoder: bool = True, + ) -> None: + """ + Implementation of modified U-Net with a single encoder connected to + three decoders for tissue and pen marking segmentation, as well as + predicting the distance to the centroid for each tissue cross-section. + + Args: + input_channels: Number of channels of the input tensor. + filters: Number of filters used in the first convolutional layer. + Each consecutive layer in the encoder path uses twice as many filters. + Each consecutive layer in the decoder path uses half as many filters. + activation: Activation function to non-linearly transform feature maps. + normalization: Type of normalization layer for feature maps. + downsample_method: Method to downsample feature maps. + downsample_factors: Factors for downsampling the feature maps. + upsample_method: Method to upsample feature maps. + residual_connection: Indicates whether a residual connection is added. + weight_init: Indicates which weight initialization method should be used. + attach_tissue_decoder: Indicates whether the tissue decoder is attached. + attach_pen_decoder: Indicates whether the pen decoder is attached. 
+ attach_distance_decoder: Indicates whether the distance decoder is attached. + """ + super().__init__() + + # define hyperparameters as instance attributes + self.filters = filters + self.activation = activation + self.normalization = normalization + self.downsample_method = downsample_method + self.downsample_factors = downsample_factors + self.upsample_method = upsample_method + self.residual_connection = residual_connection + self.weight_init = weight_init + self.attach_tissue_decoder = attach_tissue_decoder + self.attach_pen_decoder = attach_pen_decoder + self.attach_distance_decoder = attach_distance_decoder + + # check if the sepcified combination of attached decoders is valid + if (not (self.attach_tissue_decoder or self.attach_pen_decoder + or self.attach_distance_decoder)): + raise ValueError('Atleast one decoder must be attached.') + if self.attach_distance_decoder and not self.attach_tissue_decoder: + raise ValueError('The tissue segmentation decoder must be attached' + 'if the distance map decoder is attached.') + + # define the network layers + layers = { + 'block': Block(input_channels, int(self.filters), self.activation, self.normalization, self.residual_connection), + 'down1': Down(int( 1 * self.filters), int( 2 * self.filters), self.activation, self.normalization, self.downsample_factors[0], self.downsample_method, self.residual_connection), + 'down2': Down(int( 2 * self.filters), int( 4 * self.filters), self.activation, self.normalization, self.downsample_factors[1], self.downsample_method, self.residual_connection), + 'down3': Down(int( 4 * self.filters), int( 8 * self.filters), self.activation, self.normalization, self.downsample_factors[2], self.downsample_method, self.residual_connection), + 'down4': Down(int( 8 * self.filters), int( 16 * self.filters), self.activation, self.normalization, self.downsample_factors[3], self.downsample_method, self.residual_connection), + 'down5': Down(int( 16 * self.filters), int( 16 * self.filters), 
self.activation, None, self.downsample_factors[4], self.downsample_method, self.residual_connection), + } + if self.attach_tissue_decoder: + layers = { + **layers, + 'up_tissue1': Up( int( 16 * self.filters), int( 16 * self.filters), int( 16 * self.filters), self.activation, self.normalization, self.downsample_factors[4], self.upsample_method, self.residual_connection), + 'up_tissue2': Up( int( 16 * self.filters), int( 8 * self.filters), int( 8 * self.filters), self.activation, self.normalization, self.downsample_factors[3], self.upsample_method, self.residual_connection), + 'up_tissue3': Up( int( 8 * self.filters), int( 4 * self.filters), int( 4 * self.filters), self.activation, self.normalization, self.downsample_factors[2], self.upsample_method, self.residual_connection), + 'up_tissue4': Up( int( 4 * self.filters), int( 2 * self.filters), int( 2 * self.filters), self.activation, self.normalization, self.downsample_factors[1], self.upsample_method, self.residual_connection), + 'up_tissue5': Up( int( 2 * self.filters), int( 1 * self.filters), int( 1 * self.filters), self.activation, self.normalization, self.downsample_factors[0], self.upsample_method, self.residual_connection), + 'final_conv_tissue': nn.Conv2d(self.filters, 1, kernel_size=3, padding=1, padding_mode='zeros', stride=1), + } + if self.attach_pen_decoder: + layers = { + **layers, + 'up_pen1': Up( int( 16 * self.filters), int( 16 * self.filters), int( 16 * self.filters), self.activation, self.normalization, self.downsample_factors[4], self.upsample_method, self.residual_connection), + 'up_pen2': Up( int( 16 * self.filters), int( 8 * self.filters), int( 8 * self.filters), self.activation, self.normalization, self.downsample_factors[3], self.upsample_method, self.residual_connection), + 'up_pen3': Up( int( 8 * self.filters), int( 4 * self.filters), int( 4 * self.filters), self.activation, self.normalization, self.downsample_factors[2], self.upsample_method, self.residual_connection), + 'up_pen4': Up( 
int( 4 * self.filters), int( 2 * self.filters), int( 2 * self.filters), self.activation, self.normalization, self.downsample_factors[1], self.upsample_method, self.residual_connection), + 'up_pen5': Up( int( 2 * self.filters), int( 1 * self.filters), int( 1 * self.filters), self.activation, self.normalization, self.downsample_factors[0], self.upsample_method, self.residual_connection), + 'final_conv_pen': nn.Conv2d(self.filters, 1, kernel_size=3, padding=1, padding_mode='zeros', stride=1), + } + if self.attach_distance_decoder: + layers = { + **layers, + 'up_distance1': Up( int( 16 * self.filters), int( 16 * self.filters), int( 16 * self.filters), self.activation, self.normalization, self.downsample_factors[4], self.upsample_method, self.residual_connection), + 'up_distance2': Up( int( 16 * self.filters), int( 8 * self.filters), int( 8 * self.filters), self.activation, self.normalization, self.downsample_factors[3], self.upsample_method, self.residual_connection), + 'up_distance3': Up( int( 8 * self.filters), int( 4 * self.filters), int( 4 * self.filters), self.activation, self.normalization, self.downsample_factors[2], self.upsample_method, self.residual_connection), + 'up_distance4': Up( int( 4 * self.filters), int( 2 * self.filters), int( 2 * self.filters), self.activation, self.normalization, self.downsample_factors[1], self.upsample_method, self.residual_connection), + 'up_distance5': Up( int( 2 * self.filters), int( 1 * self.filters), int( 1 * self.filters), self.activation, self.normalization, self.downsample_factors[0], self.upsample_method, self.residual_connection), + 'final_conv_distance': nn.Conv2d(self.filters, 2, kernel_size=3, padding=1, padding_mode='zeros', stride=1), + } + self.layers = nn.ModuleDict(layers) + # recursively apply the initialize_weights method + # to all convolutional layers to initialize weights + self.layers.apply(self.initialize_weights) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: Input tensor. 
+ + Returns: + out: Tensor after operations. + """ + outputs = {} + + # encoder + x1 = self.layers['block'](x) + x2 = self.layers['down1'](x1) + x3 = self.layers['down2'](x2) + x4 = self.layers['down3'](x3) + x5 = self.layers['down4'](x4) + x = self.layers['down5'](x5) + + # tissue segmentation and distance map decoder + if self.attach_tissue_decoder: + x_tissue = self.layers['up_tissue1'](x5, x) + x_tissue = self.layers['up_tissue2'](x4, x_tissue) + x_tissue = self.layers['up_tissue3'](x3, x_tissue) + x_tissue = self.layers['up_tissue4'](x2, x_tissue) + x_tissue = self.layers['up_tissue5'](x1, x_tissue) + out_tissue = self.layers['final_conv_tissue'](x_tissue) + outputs['tissue'] = out_tissue + + # pen segmentation decoder + if self.attach_pen_decoder: + x_pen = self.layers['up_pen1'](x5, x) + x_pen = self.layers['up_pen2'](x4, x_pen) + x_pen = self.layers['up_pen3'](x3, x_pen) + x_pen = self.layers['up_pen4'](x2, x_pen) + x_pen = self.layers['up_pen5'](x1, x_pen) + out_pen = self.layers['final_conv_pen'](x_pen) + outputs['pen'] = out_pen + + # tissue distance map decoder + if self.attach_distance_decoder: + x_distance = self.layers['up_distance1'](x5, x) + x_distance = self.layers['up_distance2'](x4, x_distance) + x_distance = self.layers['up_distance3'](x3, x_distance) + x_distance = self.layers['up_distance4'](x2, x_distance) + x_distance = self.layers['up_distance5'](x1, x_distance) + out_distance = self.layers['final_conv_distance'](x_distance) + outputs['distance'] = torch.sigmoid(out_tissue)*out_distance + + return outputs + + def initialize_weights(self, layer: torch.nn) -> None: + """ + Initialize the weights using the specified initialization method + if it is a 2D convolutional layer. + + Args: + layer: Torch network layer. 
+ """ + # define dictionary with initialization function and names + init_methods = { + 'xavier_uniform' : nn.init.xavier_uniform_, + 'xavier_normal' : nn.init.xavier_normal_, + 'kaiming_uniform': lambda x: nn.init.kaiming_uniform_( + x, mode='fan_out', nonlinearity='relu', + ), + 'kaiming_normal' : lambda x: nn.init.kaiming_normal_( + x, mode='fan_out', nonlinearity='relu', + ), + 'zeros' : nn.init.zeros_, + } + + if isinstance(layer, nn.Conv2d) == True: + # select the specified weight initialization function and initialize + # the layer weights and biases + if self.weight_init in init_methods.keys(): + init_methods[self.weight_init](layer.weight) + if layer.bias != None: + nn.init.zeros_(layer.bias) + else: + raise ValueError('Invalid argument for initialization method.') + + def __repr__(self): + """ + Returns total and trainable number of parameters of the model. + """ + parameters = 0 + trainable_parameters = 0 + # count the total and trainable number of parameters of the model + for parameter in self.parameters(): + parameters += parameter.numel() + if parameter.requires_grad: + trainable_parameters += parameter.numel() + # create sentence with information about number of parameters + info = (f"Total number of parameters is {parameters:,}, " + f"of which {trainable_parameters:,} are trainable.\n") + + return info + + +class ModifiedUNet2(nn.Module): + + def __init__( + self, + input_channels: int, + filters: int = 8, + activation: str = 'relu', + normalization: str = 'instance', + downsample_method: str = 'max_pool', + downsample_factors: list = [2, 2, 2, 2, 2], + upsample_method: str = 'interpolate', + residual_connection: bool = False, + weight_init: str = 'kaiming_normal', + attach_tissue_decoder: bool = True, + attach_pen_decoder: bool = True, + attach_distance_decoder: bool = True, + ) -> None: + """ + Implementation of modified U-Net with a single encoder connected to + three decoders for tissue and pen marking segmentation, as well as + predicting the 
distance to the centroid for each tissue cross-section. + + Args: + input_channels: Number of channels of the input tensor. + filters: Number of filters used in the first convolutional layer. + Each consecutive layer in the encoder path uses twice as many filters. + Each consecutive layer in the decoder path uses half as many filters. + activation: Activation function to non-linearly transform feature maps. + normalization: Type of normalization layer for feature maps. + downsample_method: Method to downsample feature maps. + downsample_factors: Factors for downsampling the feature maps. + upsample_method: Method to upsample feature maps. + residual_connection: Indicates whether a residual connection is added. + weight_init: Indicates which weight initialization method should be used. + attach_tissue_decoder: Indicates whether the tissue decoder is attached. + attach_pen_decoder: Indicates whether the pen decoder is attached. + attach_distance_decoder: Indicates whether the distance decoder is attached. 
+ """ + super().__init__() + + # define hyperparameters as instance attributes + self.filters = filters + self.activation = activation + self.normalization = normalization + self.downsample_method = downsample_method + self.downsample_factors = downsample_factors + self.upsample_method = upsample_method + self.residual_connection = residual_connection + self.weight_init = weight_init + self.attach_tissue_decoder = attach_tissue_decoder + self.attach_pen_decoder = attach_pen_decoder + self.attach_distance_decoder = attach_distance_decoder + + # check if the sepcified combination of attached decoders is valid + if (not (self.attach_tissue_decoder or self.attach_pen_decoder + or self.attach_distance_decoder)): + raise ValueError('Atleast one decoder must be attached.') + if self.attach_distance_decoder and not self.attach_tissue_decoder: + raise ValueError('The tissue segmentation decoder must be attached' + 'if the distance map decoder is attached.') + + # define the network layers + layers = { + 'block': Block(input_channels, int(self.filters), self.activation, self.normalization, self.residual_connection), + 'down1': Down(int( 1 * self.filters), int( 2 * self.filters), self.activation, self.normalization, self.downsample_factors[0], self.downsample_method, self.residual_connection), + 'down2': Down(int( 2 * self.filters), int( 4 * self.filters), self.activation, self.normalization, self.downsample_factors[1], self.downsample_method, self.residual_connection), + 'down3': Down(int( 4 * self.filters), int( 8 * self.filters), self.activation, self.normalization, self.downsample_factors[2], self.downsample_method, self.residual_connection), + 'down4': Down(int( 8 * self.filters), int( 16 * self.filters), self.activation, self.normalization, self.downsample_factors[3], self.downsample_method, self.residual_connection), + 'down5': Down(int( 16 * self.filters), int( 16 * self.filters), self.activation, None, self.downsample_factors[4], self.downsample_method, 
self.residual_connection), + } + if self.attach_tissue_decoder: + layers = { + **layers, + 'up_tissue1': Up( int( 16 * self.filters), int( 16 * self.filters), int( 8 * self.filters), self.activation, self.normalization, self.downsample_factors[4], self.upsample_method, self.residual_connection), + 'up_tissue2': Up( int( 8 * self.filters), int( 8 * self.filters), int( 4 * self.filters), self.activation, self.normalization, self.downsample_factors[3], self.upsample_method, self.residual_connection), + 'up_tissue3': Up( int( 4 * self.filters), int( 4 * self.filters), int( 2 * self.filters), self.activation, self.normalization, self.downsample_factors[2], self.upsample_method, self.residual_connection), + 'up_tissue4': Up( int( 2 * self.filters), int( 2 * self.filters), int( 1 * self.filters), self.activation, self.normalization, self.downsample_factors[1], self.upsample_method, self.residual_connection), + 'up_tissue5': Up( int( 1 * self.filters), int( 1 * self.filters), int( 1 * self.filters), self.activation, self.normalization, self.downsample_factors[0], self.upsample_method, self.residual_connection), + 'final_conv_tissue': nn.Conv2d(self.filters, 1, kernel_size=3, padding=1, padding_mode='zeros', stride=1), + } + if self.attach_pen_decoder: + layers = { + **layers, + 'up_pen1': Up( int( 16 * self.filters), int( 16 * self.filters), int( 8 * self.filters), self.activation, self.normalization, self.downsample_factors[4], self.upsample_method, self.residual_connection), + 'up_pen2': Up( int( 8 * self.filters), int( 8 * self.filters), int( 4 * self.filters), self.activation, self.normalization, self.downsample_factors[3], self.upsample_method, self.residual_connection), + 'up_pen3': Up( int( 4 * self.filters), int( 4 * self.filters), int( 2 * self.filters), self.activation, self.normalization, self.downsample_factors[2], self.upsample_method, self.residual_connection), + 'up_pen4': Up( int( 2 * self.filters), int( 2 * self.filters), int( 1 * self.filters), 
self.activation, self.normalization, self.downsample_factors[1], self.upsample_method, self.residual_connection), + 'up_pen5': Up( int( 1 * self.filters), int( 1 * self.filters), int( 1 * self.filters), self.activation, self.normalization, self.downsample_factors[0], self.upsample_method, self.residual_connection), + 'final_conv_pen': nn.Conv2d(self.filters, 1, kernel_size=3, padding=1, padding_mode='zeros', stride=1), + } + if self.attach_distance_decoder: + layers = { + **layers, + 'up_distance1': Up( int( 16 * self.filters), int( 16 * self.filters), int( 8 * self.filters), self.activation, self.normalization, self.downsample_factors[4], self.upsample_method, self.residual_connection), + 'up_distance2': Up( int( 8 * self.filters), int( 8 * self.filters), int( 4 * self.filters), self.activation, self.normalization, self.downsample_factors[3], self.upsample_method, self.residual_connection), + 'up_distance3': Up( int( 4 * self.filters), int( 4 * self.filters), int( 2 * self.filters), self.activation, self.normalization, self.downsample_factors[2], self.upsample_method, self.residual_connection), + 'up_distance4': Up( int( 2 * self.filters), int( 2 * self.filters), int( 1 * self.filters), self.activation, self.normalization, self.downsample_factors[1], self.upsample_method, self.residual_connection), + 'up_distance5': Up( int( 1 * self.filters), int( 1 * self.filters), int( 1 * self.filters), self.activation, self.normalization, self.downsample_factors[0], self.upsample_method, self.residual_connection), + 'final_conv_distance': nn.Conv2d(self.filters, 2, kernel_size=3, padding=1, padding_mode='zeros', stride=1), + } + self.layers = nn.ModuleDict(layers) + # recursively apply the initialize_weights method + # to all convolutional layers to initialize weights + self.layers.apply(self.initialize_weights) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: Input tensor. + + Returns: + out: Tensor after operations. 
+ """ + outputs = {} + + # encoder + x1 = self.layers['block'](x) + x2 = self.layers['down1'](x1) + x3 = self.layers['down2'](x2) + x4 = self.layers['down3'](x3) + x5 = self.layers['down4'](x4) + x = self.layers['down5'](x5) + + # tissue segmentation and distance map decoder + if self.attach_tissue_decoder: + x_tissue = self.layers['up_tissue1'](x5, x) + x_tissue = self.layers['up_tissue2'](x4, x_tissue) + x_tissue = self.layers['up_tissue3'](x3, x_tissue) + x_tissue = self.layers['up_tissue4'](x2, x_tissue) + x_tissue = self.layers['up_tissue5'](x1, x_tissue) + out_tissue = self.layers['final_conv_tissue'](x_tissue) + outputs['tissue'] = out_tissue + + # pen segmentation decoder + if self.attach_pen_decoder: + x_pen = self.layers['up_pen1'](x5, x) + x_pen = self.layers['up_pen2'](x4, x_pen) + x_pen = self.layers['up_pen3'](x3, x_pen) + x_pen = self.layers['up_pen4'](x2, x_pen) + x_pen = self.layers['up_pen5'](x1, x_pen) + out_pen = self.layers['final_conv_pen'](x_pen) + outputs['pen'] = out_pen + + # tissue distance map decoder + if self.attach_distance_decoder: + x_distance = self.layers['up_distance1'](x5, x) + x_distance = self.layers['up_distance2'](x4, x_distance) + x_distance = self.layers['up_distance3'](x3, x_distance) + x_distance = self.layers['up_distance4'](x2, x_distance) + x_distance = self.layers['up_distance5'](x1, x_distance) + out_distance = self.layers['final_conv_distance'](x_distance) + outputs['distance'] = torch.sigmoid(out_tissue)*out_distance + + return outputs + + def initialize_weights(self, layer: torch.nn) -> None: + """ + Initialize the weights using the specified initialization method + if it is a 2D convolutional layer. + + Args: + layer: Torch network layer. 
+ """ + # define dictionary with initialization function and names + init_methods = { + 'xavier_uniform' : nn.init.xavier_uniform_, + 'xavier_normal' : nn.init.xavier_normal_, + 'kaiming_uniform': lambda x: nn.init.kaiming_uniform_( + x, mode='fan_out', nonlinearity='relu', + ), + 'kaiming_normal' : lambda x: nn.init.kaiming_normal_( + x, mode='fan_out', nonlinearity='relu', + ), + 'zeros' : nn.init.zeros_, + } + + if isinstance(layer, nn.Conv2d) == True: + # select the specified weight initialization function and initialize + # the layer weights and biases + if self.weight_init in init_methods.keys(): + init_methods[self.weight_init](layer.weight) + if layer.bias != None: + nn.init.zeros_(layer.bias) + else: + raise ValueError('Invalid argument for initialization method.') + + def __repr__(self): + """ + Returns total and trainable number of parameters of the model. + """ + parameters = 0 + trainable_parameters = 0 + # count the total and trainable number of parameters of the model + for parameter in self.parameters(): + parameters += parameter.numel() + if parameter.requires_grad: + trainable_parameters += parameter.numel() + # create sentence with information about number of parameters + info = (f"Total number of parameters is {parameters:,}, " + f"of which {trainable_parameters:,} are trainable.\n") + + return info + + +def get_model(model_name: str): + if model_name == 'ModifiedUNet': + return ModifiedUNet + elif model_name == 'ModifiedUNet2': + return ModifiedUNet2 + else: + raise ValueError('Invalid model name.') \ No newline at end of file diff --git a/runtime/ops/mapper/wsi_enhance_operator/slidesegmenter/model_files/__init__.py b/runtime/ops/mapper/wsi_enhance_operator/slidesegmenter/model_files/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/runtime/ops/mapper/wsi_enhance_operator/slidesegmenter/slidesegmenter.py b/runtime/ops/mapper/wsi_enhance_operator/slidesegmenter/slidesegmenter.py new file mode 100644 index 
# Copyright 2023 Ruben T Lucassen, UMC Utrecht, The Netherlands
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Utility class for instance segmentation of tissue cross-sections
and semantic segmentation of pen markings.
"""

import json
import os
from math import ceil, floor
from pathlib import Path
from typing import Optional, Union

import torch
try:
    # Optional Ascend NPU backend; importing it registers the 'npu' device.
    import torch_npu  # noqa: F401
except Exception:
    torch_npu = None
import numpy as np
from huggingface_hub import hf_hub_download
from scipy.ndimage import gaussian_filter, maximum_filter

from ._model_utils import get_model
from . import model_files


class SlideSegmenter:
    """
    Class for segmenting tissue and pen markings in low resolution (1.25x)
    whole slide images. The class is responsible for:
    (1) preprocessing (i.e., padding to a valid size),
    (2) running model inference to get the segmentations,
    (3) post-processing (i.e., cropping to the original size and optionally
        separating tissue cross-sections).
    """

    # Released model versions and the files each version consists of.
    available_models = {
        '2023-08-13': {'model_filename': 'model_state_dict.pth',
                       'settings_filename': 'settings.json'},
        '2024-01-10': {'model_filename': 'model_state_dict.pth',
                       'settings_filename': 'settings.json'},
        '2025-10-18': {'model_filename': 'model_state_dict.pth',
                       'settings_filename': 'settings.json'},
    }

    def __init__(
        self,
        channels_last: bool = True,
        tissue_segmentation: bool = True,
        pen_marking_segmentation: bool = True,
        separate_cross_sections: bool = False,
        device: str = 'cpu',
        model_folder: str = 'latest',
        alternative_directory: Optional[Union[str, Path]] = None,
    ) -> None:
        """
        Initialize SlideSegmenter instance.

        Args:
            channels_last: Indicates whether the input is expected to have
                the channels dimension after the spatial dimensions. If False,
                channels first is assumed.
            tissue_segmentation: Indicates whether tissue is segmented.
            pen_marking_segmentation: Indicates whether pen markings are segmented.
            separate_cross_sections: Indicates whether the segmented tissue
                cross-sections are separated.
            device: Specifies whether model inference is performed on the cpu or gpu.
            model_folder: Name of model subfolder in the model_files folder of the
                package ('latest' selects the latest model).
            alternative_directory: Optionally define an alternative directory
                to store the downloaded model files (relevant in case an error
                is thrown when the path length to the cache folder exceeds
                the maximum length on Windows systems).

        Raises:
            ValueError: If no segmentation task is selected, or if cross-section
                separation is requested without tissue segmentation.
        """
        # create instance attributes
        self.channels_last = channels_last
        self.tissue_segmentation = tissue_segmentation
        self.pen_marking_segmentation = pen_marking_segmentation
        self.separate_cross_sections = separate_cross_sections
        self.device = device
        self.model_folder = model_folder
        self.alternative_directory = alternative_directory
        self.model = None
        self.divisor = None
        self.hyperparameters = {}

        # check if the combination of the selected predictive tasks is valid
        if not (self.tissue_segmentation or self.pen_marking_segmentation):
            raise ValueError('At least one of the segmentation tasks must be selected.')
        if self.separate_cross_sections and not self.tissue_segmentation:
            raise ValueError('The separation of cross-sections can only be '
                             'performed if the tissue is segmented.')

        # convert path of alternative directory to Path instance
        if isinstance(self.alternative_directory, str):
            self.alternative_directory = Path(alternative_directory)

        # load and configure model
        self._load_model()

    def _load_model(
        self,
        model: Optional[torch.nn.Module] = None,
        model_path: Optional[Union[Path, str]] = None,
        settings_path: Optional[Union[Path, str]] = None,
    ) -> None:
        """
        Loads and configures the model.

        Args:
            model: Model class (if None, it is resolved from the settings file).
            model_path: Path to the (model) state dictionary.
            settings_path: Path to the model settings JSON.

        Raises:
            ValueError: If an invalid combination of arguments is supplied
                (model_path and settings_path must be given together).
        """
        # check whether the combination of input arguments is valid
        if model is None and model_path is None and settings_path is None:
            # resolve the directory that stores (or will store) the model files
            if self.alternative_directory is not None:
                directory = self.alternative_directory
            else:
                directory = Path(model_files.__file__).parent
            # select the latest model (version folder names sort chronologically)
            if self.model_folder == 'latest':
                self.model_folder = sorted(self.available_models)[-1]
            # check if the model has to be downloaded
            downloaded = True
            if self.model_folder not in os.listdir(directory):
                downloaded = False
            else:
                for file in self.available_models[self.model_folder].values():
                    if file not in os.listdir(directory / self.model_folder):
                        downloaded = False
            # download the model if necessary
            if not downloaded:
                print((f'Start downloading the "{self.model_folder}" '
                       'model parameters and configuration settings'))
                self._download_model(self.model_folder, directory)
            # define the model path and settings path
            model_filename = self.available_models[self.model_folder]['model_filename']
            model_path = directory / self.model_folder / model_filename
            settings_filename = self.available_models[self.model_folder]['settings_filename']
            settings_path = directory / self.model_folder / settings_filename
        elif model_path is not None and settings_path is not None:
            pass
        else:
            raise ValueError('Invalid combination of inputs')

        # load model settings
        with open(settings_path, 'r') as f:
            settings = json.load(f)

        # select model class based on the settings if not given explicitly
        if model is None:
            model = get_model(settings['model_name'])

        # store hyperparameters (thresholds, padding mode, etc.)
        if 'hyperparameters' in settings:
            self.hyperparameters = settings['hyperparameters']

        # load the model parameters
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        model_state_dict = torch.load(
            model_path,
            map_location=torch.device(self.device),
        )
        # unwrap training checkpoints that nest the state dict under a key
        if 'model_state_dict' in model_state_dict:
            model_state_dict = model_state_dict['model_state_dict']

        # drop decoder-attachment flags from the stored settings; they are
        # overridden below based on the tasks selected at construction time
        keywords = [
            'attach_tissue_decoder',
            'attach_pen_decoder',
            'attach_distance_decoder',
        ]
        settings['model'] = {
            k: v for k, v in settings['model'].items() if k not in keywords
        }
        # configure model
        self.model = model(
            **settings['model'],
            attach_tissue_decoder=self.tissue_segmentation,
            attach_pen_decoder=self.pen_marking_segmentation,
            attach_distance_decoder=self.separate_cross_sections,
        )
        # remove excess layers from the model state dictionary (in case the pen
        # or distance decoder are not used) and load it
        # NOTE(review): this keeps named parameters only; assumes the model
        # registers no persistent buffers -- TODO confirm
        model_state_dict = {
            name: model_state_dict[name] for name, _ in self.model.named_parameters()
        }
        self.model.load_state_dict(model_state_dict)
        self.model.to(self.device)
        self.model.eval()

        # the image height and width must be divisible by this value
        self.divisor = np.prod(settings['model']['downsample_factors'])

    def _download_model(
        self,
        model_folder: str,
        local_dir: Union[str, Path],
    ) -> None:
        """
        Download the model parameters and configuration settings from the
        HuggingFace Hub into `local_dir`.
        """
        # allow deployments behind slow proxies to raise the metadata timeout
        etag_timeout = float(os.environ.get('SLIDESEGMENTER_HF_ETAG_TIMEOUT', '120'))
        for filename in self.available_models[model_folder].values():
            # NOTE(review): resume_download and local_dir_use_symlinks are
            # deprecated in recent huggingface_hub releases -- confirm the
            # pinned dependency version still accepts them.
            hf_hub_download(
                repo_id='rtlucassen/slidesegmenter',
                filename=filename,
                subfolder=model_folder,
                local_dir=local_dir,
                force_download=False,
                etag_timeout=etag_timeout,
                resume_download=True,
                local_dir_use_symlinks=False,
            )

    def change_device(self, device: str) -> None:
        """
        Change the device for model inference.

        Args:
            device: Specifies whether model inference is performed on the cpu or gpu.
        """
        self.device = device
        self.model.to(self.device)

    def segment(
        self,
        image: Union[np.ndarray, torch.Tensor],
        tissue_threshold: Optional[Union[float, str]] = 'default',
        pen_marking_threshold: Optional[Union[float, str]] = 'default',
        return_distance_maps: bool = False,
    ) -> Union[np.ndarray, tuple]:
        """
        Steps in segmentation pipeline:
        (1) Preprocess the image by adding padding to make the length of the
            height and width valid.
        (2) Predict the tissue and pen marking segmentation for the image.
        (3) Post-process the segmentation by cropping it to the original size.
        (4) Optionally divide the tissue segmentations into separate cross-sections.

        Args:
            image: Whole slide image (at 1.25x) [0.0-1.0] as (height, width, channel)
                for channels last or (channel, height, width) for channels first.
            tissue_threshold: Threshold value for binarizing the predicted
                tissue segmentation ('default': the threshold value based on the
                validation set is used, None: the segmentation is not thresholded).
            pen_marking_threshold: Threshold value for binarizing the predicted
                pen marking segmentation ('default': the threshold value based on
                the validation set is used, None: the segmentation is not thresholded).
            return_distance_maps: Indicates whether the distance maps are returned.

        Returns:
            prediction: Dictionary with the following key-value pairs:
                tissue_segmentation: Segmentation for whole slide image [0.0-1.0]
                    (at 1.25x) as (height, width, channel) for channels last or
                    (channel, height, width) for channels first.
                pen_marking_segmentation: Segmentation for whole slide image [0.0-1.0]
                    (at 1.25x) as (height, width, channel) for channels last or
                    (channel, height, width) for channels first.
                distance_maps: Image with predicted horizontal and vertical distance
                    with respect to centroid as (height, width, channel) for channels
                    last or (channel, height, width) for channels first.
        """
        # check image object type, convert to numpy array if necessary
        if isinstance(image, torch.Tensor):
            # detach + cpu so tensors on an accelerator or carrying gradients
            # are handled as well (plain .numpy() would raise for those)
            image = image.detach().cpu().numpy()
        elif not isinstance(image, np.ndarray):
            raise TypeError('Invalid type of input argument for image.')

        # check if the image intensities are in the range of 0.0-1.0
        if np.min(image) < 0 or np.max(image) > 1:
            raise ValueError('Invalid image intensities (must be in the range 0.0-1.0)')

        # check if the image input argument is valid
        if len(image.shape) != 3:
            raise ValueError('Invalid number of dimensions for input argument.')

        # check if the tissue threshold for binarization is specified in case
        # the cross-sections should be separated.
        if self.separate_cross_sections and tissue_threshold is None:
            raise ValueError('The tissue threshold must be specified if the '
                             'cross-sections should be separated.')
        # check if distance maps can be returned
        if return_distance_maps and not self.separate_cross_sections:
            raise ValueError('Distance maps can only be returned when '
                             'cross-sections should be separated.')

        # change the channels dimension to be the first dimension if necessary
        if self.channels_last:
            image = image.transpose((2, 0, 1))

        # determine the height and width of the input image
        channels, height, width = image.shape

        # check if the number of channels is valid
        if channels != 3:
            raise ValueError('Invalid number of channels for input argument.')

        # determine the total amount of padding necessary to make both spatial
        # dimensions divisible by the model's cumulative downsampling factor
        width_pad = (ceil(width / self.divisor) * self.divisor) - width
        height_pad = (ceil(height / self.divisor) * self.divisor) - height

        # determine the amount of padding on each side of the image
        padding = [
            (0, 0),
            (floor(height_pad / 2), ceil(height_pad / 2)),
            (floor(width_pad / 2), ceil(width_pad / 2)),
        ]
        # add padding to image using the mode the model was trained with
        pad_kwargs = {
            'array': image,
            'pad_width': padding,
            'mode': self.hyperparameters['padding_mode'],
        }
        if self.hyperparameters['padding_mode'] == 'constant':
            pad_kwargs['constant_values'] = self.hyperparameters['padding_value']
        image = np.pad(**pad_kwargs)
        # convert the image to a torch Tensor
        image = torch.from_numpy(image).float()
        # get the model prediction
        with torch.no_grad():
            prediction = self.model(image[None, ...].to(self.device))

        # independent of the device, bring the prediction to the cpu,
        # remove the batch dimension, and crop the padding
        top = padding[1][0]
        left = padding[2][0]
        prediction = {k: v.to('cpu')[0, :, top:top+height, left:left+width]
                      for k, v in prediction.items()}

        # binarize the segmentations based on the threshold value
        if tissue_threshold == 'default':
            tissue_threshold = self.hyperparameters['tissue_threshold']
        if self.tissue_segmentation and tissue_threshold is not None:
            prediction['tissue'] = torch.sigmoid(prediction['tissue'])
            prediction['tissue'] = torch.where(prediction['tissue'] >= tissue_threshold, 1, 0)

        if pen_marking_threshold == 'default':
            pen_marking_threshold = self.hyperparameters['pen_marking_threshold']
        if self.pen_marking_segmentation and pen_marking_threshold is not None:
            prediction['pen'] = torch.sigmoid(prediction['pen'])
            prediction['pen'] = torch.where(prediction['pen'] >= pen_marking_threshold, 1, 0)

        # convert to numpy arrays
        prediction = {k: v.numpy() for k, v in prediction.items()}

        # separate the cross-sections based on the predicted distance maps
        if self.separate_cross_sections:
            if tissue_threshold is None:
                raise ValueError('Unable to separate cross-sections without '
                                 'binarizing the tissue segmentation.')
            separated_cross_sections, _ = self._separate_cross_sections(
                prediction['tissue'][0, ...],
                prediction['distance'][0, ...],
                prediction['distance'][1, ...],
            )
            prediction['tissue'] = separated_cross_sections.transpose((2, 0, 1))

        # remove distance maps from prediction if necessary
        if not return_distance_maps and 'distance' in prediction:
            del prediction['distance']

        # change the first channel back to the last channel
        if self.channels_last:
            prediction = {k: v.transpose((1, 2, 0)) for k, v in prediction.items()}

        return prediction

    def _separate_cross_sections(
        self,
        segmentation: np.ndarray,
        horizontal_distance: np.ndarray,
        vertical_distance: np.ndarray,
    ) -> tuple[np.ndarray, list[tuple[float, float]]]:
        """
        Separate cross-sections in the predicted segmentation map,
        based on the predicted horizontal and vertical distance maps.

        Args:
            segmentation: Segmentation for whole slide image [uint8] (at 1.25x)
                as (height, width).
            horizontal_distance: Image with predicted horizontal distance [float32]
                with respect to centroid as (height, width).
            vertical_distance: Image with predicted vertical distance [float32]
                with respect to centroid as (height, width).

        Returns:
            nearest_centroid_map: Segmentation for whole slide image [uint8]
                (at 1.25x) as (height, width, channel).
            centroid_coords: Coordinates of extracted centroids.
        """
        # initialize a variable with the image shape
        image_shape = segmentation.shape

        # create a vector with the binarized segmentation result for masking
        mask = segmentation.astype(bool).reshape((-1,))

        # create horizontal and vertical grid
        vertical_map, horizontal_map = np.meshgrid(
            np.linspace(0, image_shape[0]-1, image_shape[0]),
            np.linspace(0, image_shape[1]-1, image_shape[1]),
            indexing="ij",
        )
        # create the centroid maps: each pixel votes for the position of the
        # centroid of the cross-section it belongs to
        distance_factor = self.hyperparameters['distance_factor']
        x_centroid_map = (horizontal_map - (horizontal_distance*distance_factor))
        y_centroid_map = (vertical_map - (vertical_distance*distance_factor))

        # flatten the centroid map and select only the tissue regions
        x_centroid_flat = x_centroid_map.reshape((-1,))[mask]
        y_centroid_flat = y_centroid_map.reshape((-1,))[mask]

        # get hyperparameter values from dictionary
        sigma = self.hyperparameters['sigma']
        percentile = self.hyperparameters['percentile']
        filter_size = self.hyperparameters['filter_size']
        pixels_per_bin = self.hyperparameters['pixels_per_bin']

        # determine the number of bins for the histogram
        bins = [image_shape[0]//pixels_per_bin, image_shape[1]//pixels_per_bin]
        x_centroid_flat = (x_centroid_flat/pixels_per_bin).astype(int)
        y_centroid_flat = (y_centroid_flat/pixels_per_bin).astype(int)
        # add the top left and bottom right point of the histogram;
        # this prevents the histogram from removing empty rows and columns,
        # which would not change the output but can prevent confusion when
        # inspecting the histogram visually.
        # BUGFIX: the y anchor previously used bins[1] (the x bin count),
        # which is wrong for non-square thumbnails.
        x_centroid_flat = np.concatenate([x_centroid_flat, np.array([0, bins[1]])])
        y_centroid_flat = np.concatenate([y_centroid_flat, np.array([0, bins[0]])])

        # create 2D histogram of the centroid votes
        histogram, y_edges, x_edges = np.histogram2d(
            y_centroid_flat,
            x_centroid_flat,
            bins=bins,
            range=[[0, bins[0]-1],
                   [0, bins[1]-1]],
        )
        # apply Gaussian filtering to decrease local peaks
        if sigma is not None:
            histogram = gaussian_filter(histogram, sigma=sigma)
        # keep only local maxima that rise above the percentile threshold
        histogram_mask = np.where(histogram > np.percentile(histogram, percentile), 1, 0)
        max_filtered_histogram = maximum_filter(histogram, filter_size)
        maxima = np.where(histogram == max_filtered_histogram, 1, 0)*histogram_mask

        # convert the edges from ranges to the center value
        x_bins = np.array([sum(x_edges[i:i+2])/2 for i in range(bins[1])])
        y_bins = np.array([sum(y_edges[i:i+2])/2 for i in range(bins[0])])

        # get the centroid coordinates
        indices = np.argwhere(maxima)
        centroids = np.concatenate(
            [x_bins[indices[:, 1], None], y_bins[indices[:, 0], None]],
            axis=1,
        )
        centroid_coords = list(zip((centroids[:, 0]), centroids[:, 1]))

        # combine the x and y centroid maps into one array
        predicted_centroids = [
            x_centroid_map[..., None, None],
            y_centroid_map[..., None, None],
        ]
        predicted_centroid_array = np.concatenate(predicted_centroids, axis=-1)

        # flatten the array and select only the tissue regions
        predicted_centroid_flat = predicted_centroid_array.reshape((-1, 1, 2))[mask, ...]

        # for each pixel, determine the distance between the predicted centroid
        # and all extracted centroids. Broadcasting is used for efficiency:
        # - predicted_centroid_flat: [x*y, 1, 2] -> [x*y, N_centroids, 2]
        # - centroids:               [1, N_centroids, 2] -> [x*y, N_centroids, 2]
        distance_flat = np.sum((predicted_centroid_flat-centroids[None, ...])**2, axis=-1)

        # determine for each pixel what the nearest centroid is
        nearest_centroid_flat = np.argmin(distance_flat, axis=-1)+1

        # get the x and y coordinates for the pixels in the segmentation for indexing
        horizontal_flat = horizontal_map.reshape((-1,)).astype(np.uint16)[mask]
        vertical_flat = vertical_map.reshape((-1,)).astype(np.uint16)[mask]

        # convert the nearest centroid vector to the image shape
        # NOTE(review): uint8 labels assume fewer than 256 cross-sections per slide
        nearest_centroid_map = np.zeros(image_shape, dtype=np.uint8)
        nearest_centroid_map[vertical_flat, horizontal_flat] = nearest_centroid_flat

        # assign each cross-section to a separate channel
        index_map = np.tile(
            np.arange(1, len(centroid_coords)+1)[None, None, ...],
            (*image_shape, 1),
        )
        nearest_centroid_map = np.where(nearest_centroid_map[..., None] == index_map, 1, 0)

        return nearest_centroid_map, centroid_coords


# ---------------------------------------------------------------------------
# runtime/ops/mapper/wsi_enhance_operator/stain_normalization/__init__.py
# ---------------------------------------------------------------------------

"""Stain normalization package exports."""

from .stain_normalization import (
    StainNormalizer,
    StainNormalizationConfig,
    StainMethod,
    StainTemplateManager
)

__all__ = [
    "StainNormalizer",
    "StainNormalizationConfig",
    "StainMethod",
    "StainTemplateManager"
]
b/runtime/ops/mapper/wsi_enhance_operator/stain_normalization/stain_normalization.py new file mode 100644 index 000000000..a40a99171 --- /dev/null +++ b/runtime/ops/mapper/wsi_enhance_operator/stain_normalization/stain_normalization.py @@ -0,0 +1,405 @@ +""" +染色归一化模块:将 WSI patch 的染色风格统一到目标模板。 + +支持的方法: +1. Macenko 方法 - 基于 SVD 分解组织染色浓度矩阵 +2. Reinhard 方法 - 基于颜色统计特性匹配 +3. Vahadane 方法 - 基于稀疏非负矩阵分解 + +目标染色模板管理: +- 内置常用模板(H&E 标准) +- 支持用户自定义模板图像 +""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Optional, Tuple +import numpy as np + +try: + import cv2 +except Exception as e: + cv2 = None + _CV2_IMPORT_ERR = e +else: + _CV2_IMPORT_ERR = None + + +class StainMethod(Enum): + """染色归一化方法""" + MACENKO = "macenko" + REINHARD = "reinhard" + VAHADANE = "vahadane" + + +@dataclass +class StainNormalizationConfig: + """染色归一化配置""" + method: StainMethod = StainMethod.MACENKO # 默认使用 Macenko + target_image: Optional[np.ndarray] = None # 目标模板图像 + Io: float = 240.0 # 入射光强度,用于计算光密度 + beta: float = 0.15 # Macenko 阈值,用于分离背景 + normalize_background: bool = True # 是否归一化背景 + + +class StainNormalizer: + """染色归一化器""" + + # 标准 H&E 染色矩阵(从文献中获取) + STANDARD_HE_STAIN_MATRIX = np.array([ + [0.5626, 0.2159], + [0.7201, 0.8012], + [0.4062, 0.5581] + ]) + + # 标准 H&E 浓度统计 + STANDARD_HE_CONCENTRATION_STATS = { + "hematoxylin": {"mean": 1.771, "std": 0.156}, + "eosin": {"mean": 1.054, "std": 0.243} + } + + def __init__(self, config: Optional[StainNormalizationConfig] = None): + if cv2 is None: + raise ImportError( + "未安装 OpenCV(cv2),无法进行染色归一化。\n" + "请安装依赖:pip install opencv-python-headless\n" + f"底层错误:{_CV2_IMPORT_ERR}" + ) + self.cfg = config or StainNormalizationConfig() + self._target_stain_matrix: Optional[np.ndarray] = None + self._target_concentration_stats: Optional[Dict] = None + + def set_target_image(self, image: np.ndarray) -> None: + """ + 设置目标模板图像 + + :param image: 目标图像 (H, W, C), RGB 格式 + """ + if self.cfg.method 
== StainMethod.MACENKO: + self._compute_macenko_target(image) + elif self.cfg.method == StainMethod.REINHARD: + self._compute_reinhard_target(image) + elif self.cfg.method == StainMethod.VAHADANE: + self._compute_vahadane_target(image) + + def normalize(self, image: np.ndarray) -> np.ndarray: + """ + 对图像进行染色归一化 + + :param image: 输入图像 (H, W, C), RGB 格式 + :return: 归一化后的图像 + """ + if image is None or image.size == 0: + return image + + if self.cfg.method == StainMethod.MACENKO: + return self._macenko_normalize(image) + elif self.cfg.method == StainMethod.REINHARD: + return self._reinhard_normalize(image) + elif self.cfg.method == StainMethod.VAHADANE: + return self._vahadane_normalize(image) + else: + raise ValueError(f"不支持的染色归一化方法:{self.cfg.method}") + + def _rgb_to_od(self, rgb: np.ndarray) -> np.ndarray: + """ + RGB 转光密度 (Optical Density) + + OD = log(Io / I) = log(Io) - log(I) + 其中 Io 是入射光强度,I 是透射光强度 + """ + rgb = rgb.astype(np.float32) + # 防止 log(0) + rgb = np.clip(rgb, 1, self.Io) + od = np.log(self.cfg.Io) - np.log(rgb) + return od + + def _od_to_rgb(self, od: np.ndarray) -> np.ndarray: + """ + 光密度转 RGB + """ + rgb = self.cfg.Io * np.exp(-od) + rgb = np.clip(rgb, 0, 255).astype(np.uint8) + return rgb + + def _compute_macenko_target(self, image: np.ndarray) -> None: + """计算 Macenko 方法的目标参数""" + # 从目标图像提取染色矩阵和浓度统计 + od = self._rgb_to_od(image) + od_flat = od.reshape(-1, 3).T # (3, N) + + # 去除背景 + is_not_background = np.all(od_flat > self.cfg.beta, axis=0) + od_filtered = od_flat[:, is_not_background] + + if od_filtered.size == 0 or od_filtered.shape[1] < 10: + self._target_stain_matrix = self.STANDARD_HE_STAIN_MATRIX.copy() + self._target_concentration_stats = self.STANDARD_HE_CONCENTRATION_STATS.copy() + return + + # SVD 分解 + _, _, Vt = np.linalg.svd(od_filtered, full_matrices=False) + + # 取前两个主成分作为染色矩阵 (3, 2) + stain_matrix = Vt[:2, :].T # (2, 3) -> (3, 2) + + # 归一化染色矩阵的列向量 + stain_matrix = stain_matrix / np.linalg.norm(stain_matrix, axis=0, keepdims=True) + + 
# 计算浓度(添加异常处理) + try: + concentration = np.linalg.lstsq(stain_matrix, od_filtered, rcond=None)[0] + except np.linalg.LinAlgError: + # 如果计算失败,使用标准矩阵 + self._target_stain_matrix = self.STANDARD_HE_STAIN_MATRIX.copy() + self._target_concentration_stats = self.STANDARD_HE_CONCENTRATION_STATS.copy() + return + + # 统计信息 + self._target_concentration_stats = { + "hematoxylin": { + "mean": np.mean(concentration[0, :]), + "std": np.std(concentration[0, :]) + }, + "eosin": { + "mean": np.mean(concentration[1, :]), + "std": np.std(concentration[1, :]) + } + } + + self._target_stain_matrix = stain_matrix + + def _macenko_normalize(self, image: np.ndarray) -> np.ndarray: + """ + Macenko 染色归一化方法 + + 参考:Macenko et al., "A method for normalizing histology slides + for quantitative analysis", ISBI 2009. + """ + # 确保目标参数已初始化 + if self._target_stain_matrix is None or self._target_concentration_stats is None: + self._target_stain_matrix = self.STANDARD_HE_STAIN_MATRIX.copy() + self._target_concentration_stats = self.STANDARD_HE_CONCENTRATION_STATS.copy() + + # RGB 转 OD + od = self._rgb_to_od(image) + od_flat = od.reshape(-1, 3).T # (3, N) + + # 去除背景 + is_not_background = np.all(od_flat > self.cfg.beta, axis=0) + if np.sum(is_not_background) == 0: + # 全是背景,直接返回 + return image.copy() + + od_filtered = od_flat[:, is_not_background] + + # SVD 分解提取源染色矩阵 + _, _, Vt = np.linalg.svd(od_filtered, full_matrices=False) + source_stain_matrix = Vt[:2, :].T # (3, 2) + + # 归一化染色矩阵的列向量 + source_stain_matrix = source_stain_matrix / np.linalg.norm(source_stain_matrix, axis=0, keepdims=True) + + # 计算浓度 + try: + concentration = np.linalg.lstsq(source_stain_matrix, od_filtered, rcond=None)[0] + except np.linalg.LinAlgError: + return image.copy() + + # 源图像浓度统计 + source_stats = { + "hematoxylin": { + "mean": np.mean(concentration[0, :]), + "std": np.std(concentration[0, :]) + 1e-6 + }, + "eosin": { + "mean": np.mean(concentration[1, :]), + "std": np.std(concentration[1, :]) + 1e-6 + } + } + + # 
归一化浓度:将源浓度映射到目标分布 + normalized_concentration = np.zeros_like(concentration) + for i, stain_name in enumerate(["hematoxylin", "eosin"]): + normalized_concentration[i, :] = ( + (concentration[i, :] - source_stats[stain_name]["mean"]) / + source_stats[stain_name]["std"] * + self._target_concentration_stats[stain_name]["std"] + + self._target_concentration_stats[stain_name]["mean"] + ) + + # 使用目标染色矩阵重建 OD + normalized_od = np.dot(self._target_stain_matrix, normalized_concentration) + + # 重建图像 + result = np.zeros_like(image) + result_flat = result.reshape(-1, 3).T + + # 非背景区域使用归一化后的 OD + result_flat[:, is_not_background] = self.cfg.Io * np.exp(-normalized_od) + + # 背景区域保持原样 + bg_mask = ~is_not_background + if np.any(bg_mask): + result_flat[:, bg_mask] = image.reshape(-1, 3).T[:, bg_mask] + + result = np.clip(result, 0, 255).astype(np.uint8) + return result + + def _compute_reinhard_target(self, image: np.ndarray) -> None: + """ + 计算 Reinhard 方法的目标参数 + + Reinhard 方法基于 LAB 颜色空间的统计特性 + """ + # 转 LAB + bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB).astype(np.float32) + + # 计算通道统计 + self._target_concentration_stats = { + "L": {"mean": np.mean(lab[:, :, 0]), "std": np.std(lab[:, :, 0]) + 1e-6}, + "A": {"mean": np.mean(lab[:, :, 1]), "std": np.std(lab[:, :, 1]) + 1e-6}, + "B": {"mean": np.mean(lab[:, :, 2]), "std": np.std(lab[:, :, 2]) + 1e-6} + } + + def _reinhard_normalize(self, image: np.ndarray) -> np.ndarray: + """ + Reinhard 染色归一化方法 + + 参考:Reinhard et al., "Color transfer between images", + IEEE Computer Graphics and Applications 2001. 
+ """ + # 转 LAB + bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB).astype(np.float32) + + # 初始化目标参数(如果未设置) + if self._target_concentration_stats is None: + self._compute_reinhard_target(image) + # 使用标准值 + if self._target_concentration_stats is None: + self._target_concentration_stats = { + "L": {"mean": 127.0, "std": 50.0}, + "A": {"mean": 0.0, "std": 50.0}, + "B": {"mean": 0.0, "std": 50.0} + } + + # 计算源图像统计 + source_stats = { + "L": {"mean": np.mean(lab[:, :, 0]), "std": np.std(lab[:, :, 0]) + 1e-6}, + "A": {"mean": np.mean(lab[:, :, 1]), "std": np.std(lab[:, :, 1]) + 1e-6}, + "B": {"mean": np.mean(lab[:, :, 2]), "std": np.std(lab[:, :, 2]) + 1e-6} + } + + # 归一化每个通道 + for i, channel in enumerate(["L", "A", "B"]): + lab[:, :, i] = ( + (lab[:, :, i] - source_stats[channel]["mean"]) / + source_stats[channel]["std"] * + self._target_concentration_stats[channel]["std"] + + self._target_concentration_stats[channel]["mean"] + ) + + # 转回 RGB + lab = np.clip(lab, 0, 255).astype(np.uint8) + bgr_result = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR) + result = cv2.cvtColor(bgr_result, cv2.COLOR_BGR2RGB) + + return result + + def _compute_vahadane_target(self, image: np.ndarray) -> None: + """ + 计算 Vahadane 方法的目标参数 + + 使用稀疏非负矩阵分解 (SNMF) 提取染色矩阵 + """ + # 简化实现:使用 Macenko 的结果作为近似 + self._compute_macenko_target(image) + + def _vahadane_normalize(self, image: np.ndarray) -> np.ndarray: + """ + Vahadane 染色归一化方法 + + 参考:Vahadane et al., "Structure-preserving color normalization + and sparse stain separation for histological images", IEEE TMI 2016. 
+ """ + # 简化实现:使用 Macenko 方法 + return self._macenko_normalize(image) + + @property + def Io(self) -> float: + return self.cfg.Io + + @Io.setter + def Io(self, value: float) -> None: + self.cfg.Io = value + + @property + def beta(self) -> float: + return self.cfg.beta + + @beta.setter + def beta(self, value: float) -> None: + self.cfg.beta = value + + +class StainTemplateManager: + """染色模板管理器""" + + def __init__(self): + self._templates: Dict[str, np.ndarray] = {} + self._stats: Dict[str, Dict] = {} + + def add_template(self, name: str, image: np.ndarray) -> None: + """添加模板图像""" + self._templates[name] = image + self._stats[name] = { + "mean": np.mean(image, axis=(0, 1)), + "std": np.std(image, axis=(0, 1)) + } + + def get_template(self, name: str) -> Optional[np.ndarray]: + """获取模板图像""" + return self._templates.get(name) + + def get_template_names(self) -> List[str]: + """获取所有模板名称""" + return list(self._templates.keys()) + + def remove_template(self, name: str) -> bool: + """删除模板""" + if name in self._templates: + del self._templates[name] + del self._stats[name] + return True + return False + + def load_from_file(self, name: str, file_path: str) -> None: + """从文件加载模板""" + if cv2 is None: + raise ImportError("OpenCV 未安装") + + image = cv2.imread(file_path) + if image is None: + raise ValueError(f"无法加载模板图像:{file_path}") + + # BGR 转 RGB + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + self.add_template(name, image_rgb) + + def save_to_file(self, name: str, file_path: str) -> bool: + """保存模板到文件""" + template = self._templates.get(name) + if template is None: + return False + + if cv2 is None: + raise ImportError("OpenCV 未安装") + + # RGB 转 BGR 保存 + template_bgr = cv2.cvtColor(template, cv2.COLOR_RGB2BGR) + cv2.imwrite(file_path, template_bgr) + return True diff --git a/runtime/ops/mapper/wsi_enhance_operator/wsi_processor/__init__.py b/runtime/ops/mapper/wsi_enhance_operator/wsi_processor/__init__.py new file mode 100644 index 000000000..ef9bc69f6 --- /dev/null +++ 
b/runtime/ops/mapper/wsi_enhance_operator/wsi_processor/__init__.py @@ -0,0 +1,4 @@ +from .wsi_processor import WSIProcessor, ProcessorConfig, DetectionResult + +__all__ = ["WSIProcessor", "ProcessorConfig", "DetectionResult"] + diff --git a/runtime/ops/mapper/wsi_enhance_operator/wsi_processor/wsi_processor.py b/runtime/ops/mapper/wsi_enhance_operator/wsi_processor/wsi_processor.py new file mode 100644 index 000000000..75c4f5de7 --- /dev/null +++ b/runtime/ops/mapper/wsi_enhance_operator/wsi_processor/wsi_processor.py @@ -0,0 +1,1712 @@ +""" +【算法层】WSIProcessor:只负责算图,不关心图从哪读。 + +输入:RGB 缩略图 (numpy array, HxWx3) +输出: +- tissue_mask / bubble_mask / note_mask / artifact_mask (uint8 0/255) +- 轮廓(用于可视化与导出坐标) + +说明: +- tissue:HSV 饱和度/亮度阈值,形态学后得到组织区域。 +- note(笔迹):仅保留组织轮廓内的笔迹。 +- artifact(伪影):组织轮廓内与主色偏差过大的区域(LAB delta E),如红褐色染色异常。 +- bubble:可选,简单气泡检测。 +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import numpy as np + +try: + import cv2 +except Exception as e: # pragma: no cover + cv2 = None + _CV2_IMPORT_ERR = e + + +@dataclass +class ProcessorConfig: + # tissue(尽量包住整块组织,含淡染/脂肪/浅粉区) + sat_thresh: int = 8 # 饱和度下限,更低以包含很淡的粉/近白区 + val_max: int = 225 # 亮度上限,更高以包含左侧/底部浅粉 + tissue_min_area: int = 3000 # 缩略图尺度下过滤碎片(像素数) + tissue_close_kernel: int = 51 # 闭运算核,大核糊住脂肪、轮廓圆润少锯齿(可试 45–61) + tissue_open_kernel: int = 3 # 开运算核,偏小以免咬掉细组织/连接 + tissue_fill_holes: bool = True + tissue_merge_dilate: int = 17 # 合并邻近碎块:膨胀像素数,越大越易成整块 + tissue_final_close_kernel: int = 61 # 最终平滑轮廓用闭运算核(主要针对脂肪“海岸线”) + # 浅灰/无色细长杂质过滤(划痕/纤维/盖玻片边缘):亮但几乎无饱和度 → 当背景剔除 + tissue_gray_v_min: int = 200 # “很亮”的下限 + tissue_gray_s_max: int = 14 # “几乎无色”的上限(坏死浅粉通常 S 会更高) + # 细长条伪组织(扫描划痕/纤维/碎屑)过滤:防止背景细线被当组织 + tissue_remove_line_artifacts: bool = True + tissue_line_max_thickness: int = 40 # 连通域短边<=该值视为“细” + tissue_line_min_aspect: float = 5.0 # 长宽比>=该值视为“长条” + tissue_line_max_area: int = 1500000 # 面积过大不按长条删除,避免误删真正长条组织 + tissue_line_open_kernel: int = 31 # 
用于分离“细长残差”的开运算核(越大越能抹掉细线) + + # note/pen mark + note_val_max: int = 30 # 黑色阈值:更严格,避免误杀深紫色细胞核 + note_sat_max: int = 80 # 可适当限制饱和度,避免把深色组织当笔迹 + note_val_strict: int = 20 # 极暗像素(例如角落黑块)无视饱和度直接判为笔迹/伪影 + note_dark_val_max: int = 58 # 组织内仅“很暗”的像素才强制为笔迹 + note_dark_sat_max: int = 45 # 组织内“黑/灰笔迹”通常低饱和;深紫组织饱和度高,避免误标 + note_close_kernel: int = 5 + note_edge_exclusion: int = 11 + note_edge_overlap_ratio: float = 0.45 + note_edge_keep_min_aspect: float = 3.2 + note_edge_min_area: int = 35 + note_min_area: int = 25 # 笔迹连通域最小面积(像素),过滤细胞核大小的孤立点 + # 蓝墨水:HSV 色相 + 高饱和 + “蓝通道占优(B≫R,G)”联合判定,避免把深紫组织当笔迹 + ink_blue_h_min: int = 85 + ink_blue_h_max: int = 140 + ink_blue_s_min: int = 70 + ink_blue_v_max: int = 230 + ink_blue_b_over_r: int = 25 # B > R + delta + ink_blue_b_over_g: int = 15 # B > G + delta + ink_blue_expand_dilate: int = 5 + ink_blue_expand_h_min: int = 80 + ink_blue_expand_h_max: int = 150 + ink_blue_expand_s_min: int = 55 + ink_blue_expand_b_over_r: int = 16 + ink_blue_expand_b_over_g: int = 10 + ink_blue_grow_dilate: int = 5 + ink_blue_grow_h_min: int = 78 + ink_blue_grow_h_max: int = 155 + ink_blue_grow_s_min: int = 50 + ink_blue_grow_b_over_r: int = 14 + ink_blue_grow_b_over_g: int = 8 + ink_blue_grow_v_max: int = 230 + # 额外“墨迹样”检测:深色、饱和度高、呈细长条,且不在组织内 + ink_val_max: int = 120 + ink_sat_min: int = 80 + ink_min_area: int = 30 + ink_min_aspect: float = 4.0 + + # “细脖子”切断(连接两块组织的浅色/脂肪带),默认偏保守避免把整块组织切碎 + neck_val_min: int = 200 # 脖子区域亮度下限(越大=只切非常亮的桥) + neck_sat_max: int = 50 # 脖子区域饱和度上限(偏灰/白) + neck_min_area: int = 400 # 脖子最小面积,偏大以免误切真实组织连接 + neck_max_thickness: int = 28 # 只切很细的桥(像素),避免切断脂肪等宽连接 + + # 细桥断开(针对内部窄连接),核越小越保留连接、组织越完整 + bridge_kernel: int = 9 # odd,默认偏小以减少碎片化 + + # bubble(默认开启一个很轻的检测,可按需关闭) + enable_bubble: bool = False + bubble_min_area: int = 200 + + # 伪影:仅针对性检测,避免误杀正常深色/多色组织 + enable_artifact: bool = True + artifact_lab_dev_thresh: float = 42.0 # 通用颜色偏差阈值,设高以免深紫/鲜红等正常组织被标成伪影 + artifact_min_area: int = 2000 # 最小连通面积,只标大面积空白/异常,避免坏死区等被标蓝 + artifact_open_kernel: 
int = 5 # 形态学开运算核,去毛刺 + artifact_bg_v_min: int = 235 # 伪影/空白:近纯白亮度下限(V 通道) + artifact_bg_s_max: int = 12 # 伪影/空白:近纯白饱和度上限(S 通道,背景≈0) + # 深紫/蓝紫高密度组织保护:即便很暗也应当算组织,绝不标为伪影 + artifact_purple_h_min: int = 115 # OpenCV H(0-180) 紫/蓝紫下限 + artifact_purple_h_max: int = 175 # OpenCV H(0-180) 紫/蓝紫上限 + artifact_purple_s_min: int = 20 # 至少有一定饱和度,避免把灰黑当紫 + artifact_purple_v_max: int = 120 # “很暗”的上限,主要保护深紫块 + # 组织折叠:黑红色细长“带子”;若算轮廓/面积则保留为组织(不标蓝),若切图训练可标蓝剔除 + enable_folding_artifact: bool = True # 是否检测折叠 + treat_folding_as_tissue: bool = True # True=折叠算绿(组织),False=折叠标蓝(剔除) + folding_L_max: int = 70 # LAB L 上限,越暗越可能是折叠 + folding_a_min: int = 120 # LAB a 下限(OpenCV 中 128 为中性,>128 偏红),排除蓝紫 + folding_min_aspect: float = 2.5 # 长宽比下限,细长带状才当折叠 + folding_min_area: int = 400 # 折叠连通域最小面积 + + global_stain_min_area: int = 120 + global_stain_dark_v_max: int = 60 + global_stain_dark_s_max: int = 95 + global_stain_dark_min_area: int = 18 + global_stain_dark_open_kernel: int = 3 + global_stain_dark_expand_dilate: int = 17 + global_stain_red_s_min: int = 165 + global_stain_red_v_max: int = 185 + global_stain_red_min_area: int = 60 + global_stain_red_r_over_g: int = 60 + global_stain_red_r_over_b: int = 50 + global_stain_red_expand_dilate: int = 5 + global_stain_red_expand_s_min: int = 135 + global_stain_red_expand_r_over_g: int = 42 + global_stain_red_expand_r_over_b: int = 34 + global_stain_green_h_min: int = 35 + global_stain_green_h_max: int = 95 + global_stain_green_s_min: int = 70 + global_stain_green_v_max: int = 245 + global_stain_green_min_area: int = 60 + global_stain_green_g_over_r: int = 22 + global_stain_green_g_over_b: int = 14 + global_stain_green_expand_dilate: int = 7 + global_stain_green_expand_h_min: int = 30 + global_stain_green_expand_h_max: int = 110 + global_stain_green_expand_s_min: int = 55 + global_stain_green_expand_g_over_r: int = 12 + global_stain_green_expand_g_over_b: int = 8 + global_stain_green_grow_dilate: int = 13 + global_stain_green_grow_h_min: int = 25 + 
global_stain_green_grow_h_max: int = 115 + global_stain_green_grow_s_min: int = 28 + global_stain_green_grow_g_over_r: int = 2 + global_stain_green_grow_g_over_b: int = -10 + global_stain_green_close_kernel: int = 9 + global_stain_compact_area_max: int = 300 + global_stain_compact_keep_min_aspect: float = 3.0 + global_stain_red_compact_area_max: int = 2600 + global_stain_red_compact_keep_min_aspect: float = 2.1 + global_stain_purple_tissue_area_max: int = 1600 + global_stain_purple_tissue_max_aspect: float = 2.25 + global_stain_purple_tissue_h_min: int = 120 + global_stain_purple_tissue_h_max: int = 145 + global_stain_purple_tissue_s_max: int = 90 + global_stain_purple_tissue_v_min: int = 95 + global_stain_purple_tissue_overlap_min: float = 0.6 + stain_internal_score_k: float = 1.9 + stain_internal_min_area: int = 180 + stain_internal_open_kernel: int = 0 + stain_internal_score_alpha: float = 0.95 + stain_internal_v_max: int = 135 + stain_internal_s_min: int = 30 + stain_internal_h_red_hmax: int = 10 + stain_internal_h_red_hmin: int = 160 + stain_internal_h_purple_min: int = 133 + stain_internal_h_purple_max: int = 170 + stain_tissue_edge_exclusion: int = 9 + tissue_score_he_weight: float = 0.72 + tissue_score_od_weight: float = 0.28 + tissue_score_blur_sigma: float = 4.0 + tissue_he_loose_percentile: float = 25.0 + tissue_he_loose_scale: float = 0.7 + tissue_he_min: float = 0.03 + stain_support_dilate: int = 25 + stain_residual_min: float = 0.055 + stain_residual_percentile: float = 84.0 + stain_ratio_min: float = 0.34 + stain_ratio_percentile: float = 78.0 + stain_candidate_sat_min: int = 32 + stain_candidate_dark_v_max: int = 95 + stain_texture_sigma: float = 3.0 + stain_keep_texture_max: float = 0.07 + stain_seed_keep_overlap: float = 0.12 + stain_note_keep_overlap: float = 0.04 + stain_residual_keep_overlap: float = 0.18 + stain_reject_texture_min: float = 0.052 + stain_edge_reject_overlap: float = 0.58 + stain_uniform_rgb_std_max: float = 30.0 + 
stain_min_aspect_or_smooth: float = 1.7 + stain_he_residual_reject_ratio: float = 3.2 + stain_residual_no_seed_scale: float = 1.08 + stain_ratio_no_seed_scale: float = 1.08 + stain_candidate_sat_no_seed_boost: int = 8 + stain_pen_fill_ratio_max: float = 0.60 + stain_pen_width_max: float = 26.0 + stain_pen_width_cv_max: float = 0.95 + stain_residual_blob_min_area: int = 220 + stain_he_retention_max: float = 0.62 + + # morphology + morph_kernel: int = 5 + + +@dataclass +class DetectionResult: + tissue_mask: np.ndarray + bubble_mask: np.ndarray + note_mask: np.ndarray + artifact_mask: np.ndarray + global_stain_mask: np.ndarray + contours: Dict[str, List[np.ndarray]] # keys: tissue/bubble/note/artifact/global_stain + + +class WSIProcessor: + def __init__(self, config: ProcessorConfig | None = None): + if cv2 is None: + raise ImportError( + "未安装 OpenCV(cv2),无法运行算法层。\n" + "请安装依赖:pip install opencv-python-headless\n" + f"底层错误: {_CV2_IMPORT_ERR}" + ) + self.cfg = config or ProcessorConfig() + + def _to_hsv(self, rgb: np.ndarray) -> np.ndarray: + bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) + hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV) + return hsv + + @staticmethod + def _odd(k: int) -> int: + return k if (k % 2 == 1) else (k + 1) + + def _fill_holes(self, mask: np.ndarray) -> np.ndarray: + """ + 填充二值 mask 内部孔洞(mask: 0/255)。 + """ + if mask is None or mask.size == 0: + return mask + h, w = mask.shape[:2] + flood = mask.copy() + ff_mask = np.zeros((h + 2, w + 2), np.uint8) + # 从左上角背景开始 flood fill 成 255 + cv2.floodFill(flood, ff_mask, (0, 0), 255) + flood_inv = cv2.bitwise_not(flood) + filled = cv2.bitwise_or(mask, flood_inv) + return filled + + @staticmethod + def _safe_percentile(values: np.ndarray, q: float, default: float) -> float: + if values is None: + return default + arr = np.asarray(values) + if arr.size == 0: + return default + arr = arr[np.isfinite(arr)] + if arr.size == 0: + return default + return float(np.percentile(arr, q)) + + def _normalize_feature(self, 
feature: np.ndarray, low_q: float, high_q: float) -> np.ndarray: + low = self._safe_percentile(feature, low_q, 0.0) + high = self._safe_percentile(feature, high_q, low + 1e-6) + if high <= low + 1e-6: + return np.zeros(feature.shape, dtype=np.uint8) + scaled = (feature.astype(np.float32) - low) * (255.0 / (high - low)) + return np.clip(scaled, 0, 255).astype(np.uint8) + + def _rgb_to_od(self, rgb: np.ndarray) -> np.ndarray: + rgb_f = np.clip(rgb.astype(np.float32), 1.0, 255.0) / 255.0 + return -np.log(rgb_f) + + def _compute_stain_features(self, thumbnail_rgb: np.ndarray) -> Dict[str, np.ndarray]: + od = self._rgb_to_od(thumbnail_rgb) + + stain_matrix = np.array( + [ + [0.650, 0.072, 0.268], + [0.704, 0.990, 0.570], + [0.286, 0.105, 0.776], + ], + dtype=np.float32, + ) + stain_matrix /= np.linalg.norm(stain_matrix, axis=0, keepdims=True) + 1e-8 + + od_flat = od.reshape(-1, 3) + concentrations = np.maximum(0.0, od_flat @ np.linalg.pinv(stain_matrix.T)) + hema = concentrations[:, 0].reshape(od.shape[:2]).astype(np.float32) + eosin = concentrations[:, 1].reshape(od.shape[:2]).astype(np.float32) + he_sum = hema + eosin + + he_recon = concentrations[:, :2] @ stain_matrix[:, :2].T + residual = np.linalg.norm(np.maximum(0.0, od_flat - he_recon), axis=1).reshape(od.shape[:2]).astype( + np.float32 + ) + residual_ratio = residual / (he_sum + 0.03) + + gray = cv2.cvtColor(thumbnail_rgb, cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0 + sigma = max(0.5, float(self.cfg.stain_texture_sigma)) + local_mean = cv2.GaussianBlur(gray, (0, 0), sigmaX=sigma, sigmaY=sigma) + local_sq_mean = cv2.GaussianBlur(gray * gray, (0, 0), sigmaX=sigma, sigmaY=sigma) + texture = np.sqrt(np.maximum(local_sq_mean - local_mean * local_mean, 0.0)).astype(np.float32) + + return { + "od_sum": od.sum(axis=2).astype(np.float32), + "hema": hema, + "eosin": eosin, + "he_sum": he_sum.astype(np.float32), + "residual": residual, + "residual_ratio": residual_ratio.astype(np.float32), + "texture": texture, + } + + 
def _detect_tissue(self, h: np.ndarray, s: np.ndarray, v: np.ndarray, note_raw: np.ndarray) -> np.ndarray: + """ + 组织识别:输出 tissue mask(0/255)。 + + 只依赖 HSV 与配置参数;不读取/输出轮廓。 + """ + tissue = ((s > self.cfg.sat_thresh) & (v < self.cfg.val_max)).astype(np.uint8) * 255 + + # 先把“极暗伪影/笔迹”从 tissue 里剔除,避免角落黑块被当组织 + tissue = cv2.bitwise_and(tissue, cv2.bitwise_not(note_raw)) + + # 亮且几乎无色的像素(浅灰杂质)直接从组织里剔除,避免后续闭运算把它“连成线” + gray_junk = (v >= self.cfg.tissue_gray_v_min) & (s <= self.cfg.tissue_gray_s_max) + if np.any(gray_junk): + tissue[gray_junk] = 0 + + # morphology: 先闭运算连通外轮廓/填缝,再开运算去掉细碎噪声 + close_k = self._odd(max(3, int(self.cfg.tissue_close_kernel))) + open_k = self._odd(max(3, int(self.cfg.tissue_open_kernel))) + close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (close_k, close_k)) + open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (open_k, open_k)) + tissue = cv2.morphologyEx(tissue, cv2.MORPH_CLOSE, close_kernel, iterations=1) + tissue = cv2.morphologyEx(tissue, cv2.MORPH_OPEN, open_kernel, iterations=1) + + # 填洞:让轮廓更像“组织外轮廓”,而不是组织内部碎片 + if self.cfg.tissue_fill_holes: + tissue = self._fill_holes(tissue) + + # 连通域过滤:去掉非常碎的组织片段,保留较大组织区域 + tissue = self._filter_small_components(tissue, self.cfg.tissue_min_area) + + # ----------------------------- + # 组织“细脖子”切断:避免两块组织通过浅色/脂肪区连成一大片 + # 仅切断“贴边的亮窄带”,不影响组织内部浅色区域 + # ----------------------------- + mh, mw = tissue.shape[:2] + neck_seed = ( + (tissue > 0) + & (v > self.cfg.neck_val_min) + & (s < self.cfg.neck_sat_max) + ).astype(np.uint8) * 255 + if np.any(neck_seed): + neck_k = self._odd(15) + neck_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (neck_k, neck_k)) + neck = cv2.morphologyEx(neck_seed, cv2.MORPH_OPEN, neck_kernel, iterations=1) + neck = cv2.dilate(neck, neck_kernel, iterations=1) + + num_n, labels_n, stats_n, _ = cv2.connectedComponentsWithStats(neck, connectivity=8) + for i in range(1, num_n): + x, y, w, h2, area = stats_n[i] + if area < self.cfg.neck_min_area: + continue + short_side = 
max(1, min(w, h2)) + if short_side > self.cfg.neck_max_thickness: + continue + touches_border = ( + x <= 1 or y <= 1 or x + w >= mw - 2 or y + h2 >= mh - 2 + ) + if not touches_border: + continue + tissue[labels_n == i] = 0 + + tissue = self._filter_small_components(tissue, self.cfg.tissue_min_area) + + # 细桥断开(保守) + bridge_k = self._odd(self.cfg.bridge_kernel) + bridge_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (bridge_k, bridge_k)) + eroded = cv2.erode(tissue, bridge_kernel, iterations=1) + tissue = cv2.dilate(eroded, bridge_kernel, iterations=1) + tissue = self._filter_small_components(tissue, self.cfg.tissue_min_area) + + # 膨胀合并邻近碎块 + if self.cfg.tissue_merge_dilate > 0: + merge_k = self._odd(self.cfg.tissue_merge_dilate) + merge_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (merge_k, merge_k)) + tissue = cv2.dilate(tissue, merge_kernel, iterations=1) + loose_sat = max(5, self.cfg.sat_thresh - 2) + loose_val = min(235, self.cfg.val_max + 12) + loose = ((s > loose_sat) & (v < loose_val)).astype(np.uint8) * 255 + loose = cv2.bitwise_and(loose, cv2.bitwise_not(note_raw)) + tissue = cv2.bitwise_and(tissue, loose) + if self.cfg.tissue_fill_holes: + tissue = self._fill_holes(tissue) + tissue = self._filter_small_components(tissue, self.cfg.tissue_min_area) + + # 最终平滑:用更大的闭运算把脂肪细小空隙“糊住” + final_k = self._odd(max(3, int(self.cfg.tissue_final_close_kernel))) + final_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (final_k, final_k)) + tissue = cv2.morphologyEx(tissue, cv2.MORPH_CLOSE, final_kernel, iterations=1) + if self.cfg.tissue_fill_holes: + tissue = self._fill_holes(tissue) + tissue = self._filter_small_components(tissue, self.cfg.tissue_min_area) + + # 最终再剔除一次浅灰无色杂质(保险) + if np.any(gray_junk): + tissue[gray_junk] = 0 + tissue = self._filter_small_components(tissue, self.cfg.tissue_min_area) + + # 细长条伪组织过滤:去掉背景中的细长线状伪影 + if self.cfg.tissue_remove_line_artifacts and np.any(tissue > 0): + # 先用一次较大的开运算得到“主体组织”,细线会被抹掉 + open_k = 
self._odd(max(3, int(self.cfg.tissue_line_open_kernel))) + open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (open_k, open_k)) + tissue_opened = cv2.morphologyEx(tissue, cv2.MORPH_OPEN, open_kernel, iterations=1) + residual = cv2.bitwise_and(tissue, cv2.bitwise_not(tissue_opened)) + + # 在残差里找细长条并删除(即使它和主体组织靠得很近,也更容易被分离出来) + num_l, labels_l, stats_l, _ = cv2.connectedComponentsWithStats(residual, connectivity=8) + for i in range(1, num_l): + x, y, w, h2, area = stats_l[i] + short_side = max(1, min(w, h2)) + long_side = max(w, h2) + aspect = long_side / float(short_side) + if ( + short_side <= self.cfg.tissue_line_max_thickness + and aspect >= self.cfg.tissue_line_min_aspect + and area <= self.cfg.tissue_line_max_area + ): + residual[labels_l == i] = 0 + + tissue = cv2.bitwise_or(tissue_opened, residual) + tissue = self._filter_small_components(tissue, self.cfg.tissue_min_area) + + return tissue + + def _mask_core(self, mask: np.ndarray, margin: int) -> np.ndarray: + if mask is None or mask.size == 0 or margin <= 0: + return mask + k = self._odd(max(3, int(margin) * 2 + 1)) + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)) + return cv2.erode(mask, kernel, iterations=1) + + def _dilate_mask(self, mask: np.ndarray, size: int) -> np.ndarray: + if mask is None or mask.size == 0 or size <= 0: + return mask.copy() if mask is not None else mask + k = self._odd(max(3, int(size))) + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)) + return cv2.dilate(mask, kernel, iterations=1) + + def _limit_mask_to_core(self, mask: np.ndarray, base_mask: np.ndarray, margin: int) -> np.ndarray: + if mask is None or mask.size == 0: + return mask + if base_mask is None or base_mask.size == 0 or margin <= 0: + return mask + core = self._mask_core(base_mask, margin) + if core is None or not np.any(core > 0): + return np.zeros_like(mask) + return cv2.bitwise_and(mask, core) + + def _expand_color_mark( + self, + seed: np.ndarray, + hue_ok: np.ndarray, + sat_ok: 
np.ndarray, + dominant_ok: np.ndarray, + dilate_size: int, + val_ok: np.ndarray | None = None, + ) -> np.ndarray: + if seed is None or seed.size == 0 or not np.any(seed > 0): + return seed + if dilate_size <= 1: + return seed + k = self._odd(max(3, int(dilate_size))) + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)) + neighborhood = cv2.dilate(seed, kernel, iterations=1) + expanded = hue_ok & sat_ok & dominant_ok & (neighborhood > 0) + if val_ok is not None: + expanded = expanded & val_ok + return cv2.bitwise_or(seed, expanded.astype(np.uint8) * 255) + + def _suppress_edge_hugging_note(self, note: np.ndarray, tissue_mask: np.ndarray) -> np.ndarray: + if note is None or note.size == 0 or not np.any(note > 0): + return note + margin = int(self.cfg.note_edge_exclusion) + if tissue_mask is None or tissue_mask.size == 0 or margin <= 0: + return note + tissue_core = self._mask_core(tissue_mask, margin) + if tissue_core is None or not np.any(tissue_core > 0): + return note + edge_band = cv2.bitwise_and(tissue_mask, cv2.bitwise_not(tissue_core)) + num, labels, stats, _ = cv2.connectedComponentsWithStats(note, connectivity=8) + out = note.copy() + for i in range(1, num): + x, y, w, h2, area = stats[i] + if area < int(self.cfg.note_edge_min_area): + continue + comp = labels == i + edge_overlap = int(np.count_nonzero(comp & (edge_band > 0))) + if edge_overlap <= 0: + continue + overlap_ratio = edge_overlap / float(max(1, area)) + short_side = max(1, min(w, h2)) + long_side = max(w, h2) + aspect = long_side / float(short_side) + if overlap_ratio >= float(self.cfg.note_edge_overlap_ratio) and aspect < float( + self.cfg.note_edge_keep_min_aspect + ): + out[comp] = 0 + return out + + def _filter_small_compact_components( + self, + mask: np.ndarray, + area_max: int, + keep_min_aspect: float, + ) -> np.ndarray: + if mask is None or mask.size == 0 or not np.any(mask > 0): + return mask + if area_max <= 0: + return mask + num, labels, stats, _ = 
cv2.connectedComponentsWithStats(mask, connectivity=8) + out = mask.copy() + for i in range(1, num): + x, y, w, h2, area = [int(v) for v in stats[i]] + if area > int(area_max): + continue + long_side = max(w, h2) + short_side = max(1, min(w, h2)) + aspect = long_side / float(short_side) + if aspect < float(keep_min_aspect): + out[labels == i] = 0 + return out + + def _filter_tissue_like_purple_components( + self, + mask: np.ndarray, + tissue_mask: np.ndarray, + h: np.ndarray, + s: np.ndarray, + v: np.ndarray, + ) -> np.ndarray: + if mask is None or mask.size == 0 or not np.any(mask > 0): + return mask + area_max = int(self.cfg.global_stain_purple_tissue_area_max) + if area_max <= 0: + return mask + + tissue_bool = (tissue_mask > 0) if tissue_mask is not None and tissue_mask.size > 0 else None + out = mask.copy() + num, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8) + for i in range(1, num): + x, y, w, h2, area = [int(vv) for vv in stats[i]] + if area > area_max: + continue + + long_side = max(w, h2) + short_side = max(1, min(w, h2)) + aspect = long_side / float(short_side) + if aspect > float(self.cfg.global_stain_purple_tissue_max_aspect): + continue + + comp = labels == i + if tissue_bool is not None: + overlap = int(np.count_nonzero(comp & tissue_bool)) + overlap_ratio = overlap / float(max(1, area)) + if overlap_ratio < float(self.cfg.global_stain_purple_tissue_overlap_min): + continue + + mean_h = float(h[comp].mean()) + mean_s = float(s[comp].mean()) + mean_v = float(v[comp].mean()) + if ( + float(self.cfg.global_stain_purple_tissue_h_min) <= mean_h <= float(self.cfg.global_stain_purple_tissue_h_max) + and mean_s <= float(self.cfg.global_stain_purple_tissue_s_max) + and mean_v >= float(self.cfg.global_stain_purple_tissue_v_min) + ): + out[comp] = 0 + return out + + def _detect_note( + self, + thumbnail_rgb: np.ndarray, + h: np.ndarray, + s: np.ndarray, + v: np.ndarray, + tissue: np.ndarray, + mask_inside_tissue: np.ndarray, + 
note_raw: np.ndarray, + ) -> np.ndarray: + """ + 笔迹识别:输出 note mask(0/255),只保留组织轮廓内部。 + """ + # 复用组织 open kernel(对笔迹去噪) + open_k = self._odd(max(3, int(self.cfg.tissue_open_kernel))) + open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (open_k, open_k)) + + note = cv2.morphologyEx(note_raw, cv2.MORPH_OPEN, open_kernel, iterations=1) + note = cv2.bitwise_and(note, mask_inside_tissue) + + # 蓝墨水:色相+饱和度+蓝通道占优,整块填实 + r = thumbnail_rgb[:, :, 0].astype(np.int16) + g = thumbnail_rgb[:, :, 1].astype(np.int16) + b = thumbnail_rgb[:, :, 2].astype(np.int16) + blue_dominant = (b > r + self.cfg.ink_blue_b_over_r) & (b > g + self.cfg.ink_blue_b_over_g) + blue_ink = ( + (h >= self.cfg.ink_blue_h_min) + & (h <= self.cfg.ink_blue_h_max) + & (s >= self.cfg.ink_blue_s_min) + & (v <= self.cfg.ink_blue_v_max) + & blue_dominant + & (mask_inside_tissue > 0) + ).astype(np.uint8) * 255 + note = cv2.bitwise_or(note, blue_ink) + + if np.any(blue_ink > 0): + blue_soft = self._expand_color_mark( + blue_ink, + (h >= self.cfg.ink_blue_expand_h_min) & (h <= self.cfg.ink_blue_expand_h_max), + s >= self.cfg.ink_blue_expand_s_min, + (b > r + self.cfg.ink_blue_expand_b_over_r) + & (b > g + self.cfg.ink_blue_expand_b_over_g), + self.cfg.ink_blue_expand_dilate, + ) + note = cv2.bitwise_or(note, blue_soft) + blue_grow = self._expand_color_mark( + blue_soft, + (h >= self.cfg.ink_blue_grow_h_min) & (h <= self.cfg.ink_blue_grow_h_max), + s >= self.cfg.ink_blue_grow_s_min, + (b >= r + self.cfg.ink_blue_grow_b_over_r) + & (b >= g + self.cfg.ink_blue_grow_b_over_g), + self.cfg.ink_blue_grow_dilate, + v <= self.cfg.ink_blue_grow_v_max, + ) + note = cv2.bitwise_or(note, blue_grow) + + # 墨迹样:深色、高饱和、细长;只保留在组织轮廓内 + ink = ( + (v < self.cfg.ink_val_max) + & (s > self.cfg.ink_sat_min) + & (tissue == 0) + ).astype(np.uint8) * 255 + ink = cv2.bitwise_and(ink, mask_inside_tissue) + num_ink, labels_ink, stats_ink, _ = cv2.connectedComponentsWithStats(ink, connectivity=8) + for i in range(1, num_ink): + x, y, w, h2, 
area = stats_ink[i] + if area < self.cfg.ink_min_area: + continue + long_side = max(w, h2) + short_side = max(1, min(w, h2)) + aspect = long_side / float(short_side) + if aspect < self.cfg.ink_min_aspect: + continue + note[labels_ink == i] = 255 + + # 黑/灰笔迹兜底:很暗 + 低饱和 + dark_inside_note = ( + (v < self.cfg.note_dark_val_max) + & (s <= self.cfg.note_dark_sat_max) + & (mask_inside_tissue > 0) + ) + note = cv2.bitwise_or(note, (dark_inside_note.astype(np.uint8) * 255)) + + # 过滤细胞核大小的孤立笔迹点 + very_dark_inside = ( + (v < 35) + & (s <= min(80, self.cfg.note_dark_sat_max + 20)) + & (mask_inside_tissue > 0) + ).astype(np.uint8) * 255 + very_dark_inside = cv2.morphologyEx(very_dark_inside, cv2.MORPH_OPEN, open_kernel, iterations=1) + very_dark_inside = self._filter_small_components(very_dark_inside, max(40, self.cfg.note_min_area)) + note = cv2.bitwise_or(note, very_dark_inside) + + dark_ink_candidate = ( + (v > 35) + & (v < 70) + & (s < 55) + & (mask_inside_tissue > 0) + ).astype(np.uint8) * 255 + dark_ink_candidate = cv2.morphologyEx(dark_ink_candidate, cv2.MORPH_OPEN, open_kernel, iterations=1) + num_dark, labels_dark, stats_dark, _ = cv2.connectedComponentsWithStats( + dark_ink_candidate, connectivity=8 + ) + for i in range(1, num_dark): + x, y, w, h2, area = stats_dark[i] + if area < 70: + continue + long_side = max(w, h2) + short_side = max(1, min(w, h2)) + aspect = long_side / float(short_side) + if aspect >= 2.6: + note[labels_dark == i] = 255 + + close_k = self._odd(max(3, int(self.cfg.note_close_kernel))) + close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (close_k, close_k)) + note = cv2.morphologyEx(note, cv2.MORPH_CLOSE, close_kernel, iterations=1) + note = cv2.bitwise_and(note, mask_inside_tissue) + note = self._suppress_edge_hugging_note(note, tissue) + note = self._filter_small_components(note, self.cfg.note_min_area) + return note + + def _detect_artifact( + self, + thumbnail_rgb: np.ndarray, + h: np.ndarray, + s: np.ndarray, + v: np.ndarray, + 
mask_inside_tissue: np.ndarray, + note: np.ndarray, + ) -> np.ndarray: + """ + 伪影识别:输出 artifact mask(0/255)。 + + 当前策略:只在组织轮廓内识别“近纯白空洞/裂隙”(背景);可选折叠剔除; + 并做深紫高密度组织保护与笔迹优先级。 + """ + artifact = np.zeros_like(mask_inside_tissue) + if not (self.cfg.enable_artifact and np.any(mask_inside_tissue > 0)): + return artifact + + bg_candidate = ( + (mask_inside_tissue > 0) + & (v >= self.cfg.artifact_bg_v_min) + & (s <= self.cfg.artifact_bg_s_max) + ).astype(np.uint8) * 255 + art_open_k = self._odd(max(3, self.cfg.artifact_open_kernel)) + art_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (art_open_k, art_open_k)) + artifact = cv2.morphologyEx(bg_candidate, cv2.MORPH_OPEN, art_kernel, iterations=1) + artifact = self._filter_small_components(artifact, self.cfg.artifact_min_area) + + # 折叠:可选剔除 + if self.cfg.enable_folding_artifact and not self.cfg.treat_folding_as_tissue: + bgr = cv2.cvtColor(thumbnail_rgb, cv2.COLOR_RGB2BGR) + lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB) + l_ch, a_ch, _b_ch = cv2.split(lab) + folding_candidate = ( + (mask_inside_tissue > 0) + & (l_ch < self.cfg.folding_L_max) + & (a_ch.astype(np.int32) > self.cfg.folding_a_min) + ).astype(np.uint8) * 255 + num_f, labels_f, stats_f, _ = cv2.connectedComponentsWithStats( + folding_candidate, connectivity=8 + ) + for i in range(1, num_f): + x, y, w, h2, area = stats_f[i] + if area < self.cfg.folding_min_area: + continue + long_side = max(w, h2) + short_side = max(1, min(w, h2)) + if long_side / short_side < self.cfg.folding_min_aspect: + continue + artifact[labels_f == i] = 255 + + # 深紫/蓝紫高密度组织保护:绝不标为伪影 + purple_dense = ( + (mask_inside_tissue > 0) + & (h >= self.cfg.artifact_purple_h_min) + & (h <= self.cfg.artifact_purple_h_max) + & (s >= self.cfg.artifact_purple_s_min) + & (v <= self.cfg.artifact_purple_v_max) + ).astype(np.uint8) * 255 + artifact = cv2.bitwise_and(artifact, cv2.bitwise_not(purple_dense)) + + # 笔迹优先:已判为 Note 的不再标为 Artifact + artifact = cv2.bitwise_and(artifact, cv2.bitwise_not(note)) + 
return artifact + + def _detect_stain_internal_stats( + self, + tissue_mask: np.ndarray, + h: np.ndarray, + s: np.ndarray, + v: np.ndarray, + ) -> np.ndarray: + if tissue_mask is None or tissue_mask.size == 0: + return np.zeros_like(tissue_mask) + + hue_red = (h <= self.cfg.stain_internal_h_red_hmax) | (h >= self.cfg.stain_internal_h_red_hmin) + hue_purple = (h >= self.cfg.stain_internal_h_purple_min) & (h <= self.cfg.stain_internal_h_purple_max) + hue_keep = hue_red | hue_purple + + tissue_bin = (tissue_mask > 0).astype(np.uint8) + tissue_core = self._mask_core(tissue_bin * 255, self.cfg.stain_tissue_edge_exclusion) + tissue_core_bool = tissue_core > 0 if tissue_core is not None else tissue_bin > 0 + num, labels, stats, _ = cv2.connectedComponentsWithStats(tissue_bin, connectivity=8) + if num <= 1: + return np.zeros_like(tissue_mask) + + s_norm = s.astype(np.float32) / 255.0 + v_norm = v.astype(np.float32) / 255.0 + alpha = float(self.cfg.stain_internal_score_alpha) + score = alpha * s_norm + (1.0 - alpha) * (1.0 - v_norm) + + out = np.zeros_like(tissue_mask) + k = float(self.cfg.stain_internal_score_k) + for i in range(1, num): + area = int(stats[i, cv2.CC_STAT_AREA]) + if area < max(1, int(self.cfg.stain_internal_min_area)): + continue + comp = labels == i + if self.cfg.stain_tissue_edge_exclusion > 0: + comp = comp & tissue_core_bool + if not np.any(comp): + continue + + comp_scores = score[comp] + thr = float(np.mean(comp_scores)) + k * float(np.std(comp_scores)) + comp_stain = ( + comp + & hue_keep + & (score >= thr) + & (v <= self.cfg.stain_internal_v_max) + & (s >= self.cfg.stain_internal_s_min) + ) + out[comp_stain] = 255 + + if self.cfg.stain_internal_open_kernel > 0 and np.any(out > 0): + k2 = self._odd(max(3, int(self.cfg.stain_internal_open_kernel))) + kern = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k2, k2)) + out = cv2.morphologyEx(out, cv2.MORPH_OPEN, kern, iterations=1) + + return self._filter_small_components(out, 
int(self.cfg.stain_internal_min_area)) + + def _detect_global_stain( + self, + thumbnail_rgb: np.ndarray, + h: np.ndarray, + s: np.ndarray, + v: np.ndarray, + note_raw: np.ndarray, + tissue_mask: np.ndarray, + note: np.ndarray, + ) -> np.ndarray: + close_k = self._odd(max(3, int(self.cfg.note_close_kernel))) + close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (close_k, close_k)) + + _ = note_raw + global_stain = note.copy() + + ink_global = ((v < self.cfg.ink_val_max) & (s > self.cfg.ink_sat_min)).astype(np.uint8) * 255 + num_ink, labels_ink, stats_ink, _ = cv2.connectedComponentsWithStats(ink_global, connectivity=8) + for i in range(1, num_ink): + x, y, w, h2, area = stats_ink[i] + if area < self.cfg.ink_min_area: + continue + long_side = max(w, h2) + short_side = max(1, min(w, h2)) + aspect = long_side / float(short_side) + if aspect < self.cfg.ink_min_aspect: + continue + global_stain[labels_ink == i] = 255 + + r = thumbnail_rgb[:, :, 0].astype(np.int16) + g = thumbnail_rgb[:, :, 1].astype(np.int16) + b = thumbnail_rgb[:, :, 2].astype(np.int16) + + blue_seed = ( + (h >= self.cfg.ink_blue_h_min) + & (h <= self.cfg.ink_blue_h_max) + & (s >= self.cfg.ink_blue_s_min) + & (v <= self.cfg.ink_blue_v_max) + & (b > r + self.cfg.ink_blue_b_over_r) + & (b > g + self.cfg.ink_blue_b_over_g) + ).astype(np.uint8) * 255 + if np.any(blue_seed > 0): + blue_soft = self._expand_color_mark( + blue_seed, + (h >= self.cfg.ink_blue_expand_h_min) & (h <= self.cfg.ink_blue_expand_h_max), + s >= self.cfg.ink_blue_expand_s_min, + (b > r + self.cfg.ink_blue_expand_b_over_r) + & (b > g + self.cfg.ink_blue_expand_b_over_g), + self.cfg.ink_blue_expand_dilate, + ) + blue_grow = self._expand_color_mark( + blue_soft, + (h >= self.cfg.ink_blue_grow_h_min) & (h <= self.cfg.ink_blue_grow_h_max), + s >= self.cfg.ink_blue_grow_s_min, + (b >= r + self.cfg.ink_blue_grow_b_over_r) + & (b >= g + self.cfg.ink_blue_grow_b_over_g), + self.cfg.ink_blue_grow_dilate, + v <= 
self.cfg.ink_blue_grow_v_max, + ) + global_stain = cv2.bitwise_or(global_stain, blue_grow) + + green_seed = ( + (h >= self.cfg.global_stain_green_h_min) + & (h <= self.cfg.global_stain_green_h_max) + & (s >= self.cfg.global_stain_green_s_min) + & (v <= self.cfg.global_stain_green_v_max) + & (g >= r + self.cfg.global_stain_green_g_over_r) + & (g >= b + self.cfg.global_stain_green_g_over_b) + ).astype(np.uint8) * 255 + if np.any(green_seed > 0): + green_seed = cv2.morphologyEx(green_seed, cv2.MORPH_CLOSE, close_kernel, iterations=1) + green_soft = self._expand_color_mark( + green_seed, + (h >= self.cfg.global_stain_green_expand_h_min) + & (h <= self.cfg.global_stain_green_expand_h_max), + s >= self.cfg.global_stain_green_expand_s_min, + (g >= r + self.cfg.global_stain_green_expand_g_over_r) + & (g >= b + self.cfg.global_stain_green_expand_g_over_b), + self.cfg.global_stain_green_expand_dilate, + ) + green_soft = self._expand_color_mark( + green_soft, + (h >= self.cfg.global_stain_green_grow_h_min) + & (h <= self.cfg.global_stain_green_grow_h_max), + s >= self.cfg.global_stain_green_grow_s_min, + (g >= r + self.cfg.global_stain_green_grow_g_over_r) + & (g >= b + self.cfg.global_stain_green_grow_g_over_b), + self.cfg.global_stain_green_grow_dilate, + ) + green_close_k = self._odd(max(3, int(self.cfg.global_stain_green_close_kernel))) + green_close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (green_close_k, green_close_k)) + green_soft = cv2.morphologyEx(green_soft, cv2.MORPH_CLOSE, green_close_kernel, iterations=1) + green_soft = self._filter_small_components(green_soft, self.cfg.global_stain_green_min_area) + global_stain = cv2.bitwise_or(global_stain, green_soft) + + red_seed = ( + ((h <= 10) | (h >= 160)) + & (s >= self.cfg.global_stain_red_s_min) + & (v <= self.cfg.global_stain_red_v_max) + & (r >= g + self.cfg.global_stain_red_r_over_g) + & (r >= b + self.cfg.global_stain_red_r_over_b) + ).astype(np.uint8) * 255 + if np.any(red_seed > 0): + red_seed = 
cv2.morphologyEx(red_seed, cv2.MORPH_CLOSE, close_kernel, iterations=1) + red_seed = self._expand_color_mark( + red_seed, + ((h <= 12) | (h >= 158)), + s >= self.cfg.global_stain_red_expand_s_min, + (r >= g + self.cfg.global_stain_red_expand_r_over_g) + & (r >= b + self.cfg.global_stain_red_expand_r_over_b), + self.cfg.global_stain_red_expand_dilate, + ) + red_seed = self._filter_small_compact_components( + red_seed, + area_max=self.cfg.global_stain_red_compact_area_max, + keep_min_aspect=self.cfg.global_stain_red_compact_keep_min_aspect, + ) + red_seed = self._filter_small_components(red_seed, self.cfg.global_stain_red_min_area) + global_stain = cv2.bitwise_or(global_stain, red_seed) + + dark_open_k = self._odd(max(3, int(self.cfg.global_stain_dark_open_kernel))) + dark_open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dark_open_k, dark_open_k)) + if np.any(note > 0): + dark_expand_k = self._odd(max(3, int(self.cfg.global_stain_dark_expand_dilate))) + dark_expand_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dark_expand_k, dark_expand_k)) + note_neighbor = cv2.dilate(note, dark_expand_kernel, iterations=1) + dark_inside = ( + (v <= self.cfg.global_stain_dark_v_max) + & (s <= self.cfg.global_stain_dark_s_max) + & (tissue_mask > 0) + & (note_neighbor > 0) + ).astype(np.uint8) * 255 + dark_inside = cv2.morphologyEx(dark_inside, cv2.MORPH_CLOSE, dark_open_kernel, iterations=1) + dark_inside = self._filter_small_components(dark_inside, self.cfg.global_stain_dark_min_area) + global_stain = cv2.bitwise_or(global_stain, dark_inside) + + global_stain = cv2.bitwise_or(global_stain, note) + + global_stain = cv2.morphologyEx(global_stain, cv2.MORPH_CLOSE, close_kernel, iterations=1) + global_stain = self._filter_small_compact_components( + global_stain, + area_max=self.cfg.global_stain_compact_area_max, + keep_min_aspect=self.cfg.global_stain_compact_keep_min_aspect, + ) + global_stain = self._filter_tissue_like_purple_components( + global_stain, + 
tissue_mask=tissue_mask,
            h=h,
            s=s,
            v=v,
        )
        return self._filter_small_components(global_stain, self.cfg.global_stain_min_area)

    def _detect_tissue(
        self,
        h: np.ndarray,
        s: np.ndarray,
        v: np.ndarray,
        note_raw: np.ndarray,
        he_sum: np.ndarray,
        od_sum: np.ndarray,
    ) -> np.ndarray:
        """Build the binary tissue mask (uint8, 0/255) for the thumbnail.

        h/s/v are the HSV planes of the thumbnail; note_raw is the raw
        pen-mark/very-dark mask; he_sum and od_sum are per-pixel stain
        features computed upstream.  Returns a mask with the same spatial
        shape as the inputs.
        """
        # Fuse the H&E and optical-density features into one score, smooth
        # it, and take an Otsu threshold as the initial foreground seed.
        he_score = self._normalize_feature(he_sum, low_q=1.0, high_q=99.5)
        od_score = self._normalize_feature(od_sum, low_q=1.0, high_q=99.5)
        tissue_score = cv2.addWeighted(
            he_score,
            float(self.cfg.tissue_score_he_weight),
            od_score,
            float(self.cfg.tissue_score_od_weight),
            0.0,
        )
        blur_sigma = max(0.5, float(self.cfg.tissue_score_blur_sigma))
        tissue_score = cv2.GaussianBlur(tissue_score, (0, 0), sigmaX=blur_sigma, sigmaY=blur_sigma)
        _, tissue_seed = cv2.threshold(tissue_score, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Loosened HSV gate (slightly lower saturation / higher value bound
        # than configured) so faint tissue at the seed edges survives.
        loose_sat = max(3, self.cfg.sat_thresh - 4)
        loose_val = min(245, self.cfg.val_max + 18)
        loose_hsv = (s > loose_sat) & (v < loose_val)
        foreground = tissue_seed > 0
        # Percentile-based H&E threshold measured inside the seed region.
        he_loose_thr = max(
            float(self.cfg.tissue_he_min),
            self._safe_percentile(
                he_sum[foreground],
                float(self.cfg.tissue_he_loose_percentile),
                float(self.cfg.tissue_he_min),
            )
            * float(self.cfg.tissue_he_loose_scale),
        )
        tissue = (foreground & (loose_hsv | (he_sum >= he_loose_thr))).astype(np.uint8) * 255
        # Pen marks / very dark artifacts are never counted as tissue.
        tissue = cv2.bitwise_and(tissue, cv2.bitwise_not(note_raw))

        # Bright, desaturated pixels (glass / dust) are removed outright.
        gray_junk = (v >= self.cfg.tissue_gray_v_min) & (s <= self.cfg.tissue_gray_s_max)
        if np.any(gray_junk):
            tissue[gray_junk] = 0

        close_k = self._odd(max(3, int(self.cfg.tissue_close_kernel)))
        open_k = self._odd(max(3, int(self.cfg.tissue_open_kernel)))
        close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (close_k, close_k))
        open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (open_k, open_k))
        tissue = cv2.morphologyEx(tissue, cv2.MORPH_CLOSE, close_kernel, iterations=1)
        tissue = cv2.morphologyEx(tissue, cv2.MORPH_OPEN, open_kernel, iterations=1)

        if self.cfg.tissue_fill_holes:
            tissue = self._fill_holes(tissue)
        tissue = self._filter_small_components(tissue, self.cfg.tissue_min_area)

        # Suppress thin, bright, low-saturation "necks": elongated components
        # that touch the image border are usually slide-edge glare.
        mh, mw = tissue.shape[:2]
        neck_seed = (
            (tissue > 0)
            & (v > self.cfg.neck_val_min)
            & (s < self.cfg.neck_sat_max)
        ).astype(np.uint8) * 255
        if np.any(neck_seed):
            neck_k = self._odd(15)
            neck_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (neck_k, neck_k))
            neck = cv2.morphologyEx(neck_seed, cv2.MORPH_OPEN, neck_kernel, iterations=1)
            neck = cv2.dilate(neck, neck_kernel, iterations=1)

            num_n, labels_n, stats_n, _ = cv2.connectedComponentsWithStats(neck, connectivity=8)
            for i in range(1, num_n):
                x, y, w, h2, area = [int(vv) for vv in stats_n[i]]
                if area < int(self.cfg.neck_min_area):
                    continue
                short_side = max(1, min(w, h2))
                if short_side > int(self.cfg.neck_max_thickness):
                    continue
                # Only drop components whose bounding box touches the border.
                touches_border = x <= 1 or y <= 1 or x + w >= mw - 2 or y + h2 >= mh - 2
                if not touches_border:
                    continue
                tissue[labels_n == i] = 0

            tissue = self._filter_small_components(tissue, self.cfg.tissue_min_area)

        # Erode-then-dilate with the same kernel: breaks thin bridges between
        # blobs without shrinking the surviving regions overall.
        bridge_k = self._odd(int(self.cfg.bridge_kernel))
        bridge_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (bridge_k, bridge_k))
        eroded = cv2.erode(tissue, bridge_kernel, iterations=1)
        tissue = cv2.dilate(eroded, bridge_kernel, iterations=1)
        tissue = self._filter_small_components(tissue, self.cfg.tissue_min_area)

        if self.cfg.tissue_merge_dilate > 0:
            # Dilate to merge nearby fragments, then clamp back to a loosened
            # HSV / H&E support so the dilation cannot invent new tissue.
            merge_k = self._odd(int(self.cfg.tissue_merge_dilate))
            merge_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (merge_k, merge_k))
            tissue = cv2.dilate(tissue, merge_kernel, iterations=1)
            loose = (loose_hsv | (he_sum >= max(float(self.cfg.tissue_he_min), he_loose_thr * 0.8))).astype(
                np.uint8
            ) * 255
            loose = cv2.bitwise_and(loose, cv2.bitwise_not(note_raw))
            tissue = cv2.bitwise_and(tissue, loose)
            if self.cfg.tissue_fill_holes:
                tissue = self._fill_holes(tissue)
            tissue = self._filter_small_components(tissue, self.cfg.tissue_min_area)

        final_k = self._odd(max(3, int(self.cfg.tissue_final_close_kernel)))
        final_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (final_k, final_k))
        tissue = cv2.morphologyEx(tissue, cv2.MORPH_CLOSE, final_kernel, iterations=1)
        if self.cfg.tissue_fill_holes:
            tissue = self._fill_holes(tissue)
        tissue = self._filter_small_components(tissue, self.cfg.tissue_min_area)

        # Re-apply the gray-junk veto: the closes/dilations above may have
        # re-introduced bright desaturated pixels.
        if np.any(gray_junk):
            tissue[gray_junk] = 0
            tissue = self._filter_small_components(tissue, self.cfg.tissue_min_area)

        if self.cfg.tissue_remove_line_artifacts and np.any(tissue > 0):
            # Remove thin line artifacts: whatever an opening deletes (the
            # residual) is inspected component-by-component; long, thin, small
            # residues are dropped before the remainder is restored.
            open_k = self._odd(max(3, int(self.cfg.tissue_line_open_kernel)))
            open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (open_k, open_k))
            tissue_opened = cv2.morphologyEx(tissue, cv2.MORPH_OPEN, open_kernel, iterations=1)
            residual_mask = cv2.bitwise_and(tissue, cv2.bitwise_not(tissue_opened))

            num_l, labels_l, stats_l, _ = cv2.connectedComponentsWithStats(residual_mask, connectivity=8)
            for i in range(1, num_l):
                _, _, w, h2, area = [int(vv) for vv in stats_l[i]]
                short_side = max(1, min(w, h2))
                long_side = max(w, h2)
                aspect = long_side / float(short_side)
                if (
                    short_side <= int(self.cfg.tissue_line_max_thickness)
                    and aspect >= float(self.cfg.tissue_line_min_aspect)
                    and area <= int(self.cfg.tissue_line_max_area)
                ):
                    residual_mask[labels_l == i] = 0

            tissue = cv2.bitwise_or(tissue_opened, residual_mask)
            tissue = self._filter_small_components(tissue, self.cfg.tissue_min_area)

        return tissue

    def _component_width_stats(self, component_mask: np.ndarray) -> Tuple[float, float]:
        """Return (mean stroke width, width coefficient of variation) for a mask.

        Width is estimated as twice the distance-transform value at each
        foreground pixel.  Empty or blank masks yield (0.0, 1.0).
        """
        if component_mask is None or component_mask.size == 0 or not np.any(component_mask > 0):
            return 0.0, 1.0

        dist = cv2.distanceTransform(component_mask, cv2.DIST_L2, 5)
        widths = dist[component_mask > 0]
        widths = widths[widths > 0.5] * 2.0
        if widths.size == 0:
            return 0.0, 1.0

        mean_width = float(widths.mean())
        width_cv = float(widths.std() /
max(mean_width, 1e-6))
        return mean_width, width_cv

    def _classify_stain_components(
        self,
        candidate: np.ndarray,
        thumbnail_rgb: np.ndarray,
        tissue_mask: np.ndarray,
        color_seed: np.ndarray,
        note_seed: np.ndarray,
        residual_seed: np.ndarray,
        s: np.ndarray,
        v: np.ndarray,
        he_sum: np.ndarray,
        residual: np.ndarray,
        residual_ratio: np.ndarray,
        texture: np.ndarray,
        residual_thr: float,
        ratio_thr: float,
    ) -> np.ndarray:
        """Keep only candidate components that look like stains/pen marks.

        Each connected component of *candidate* is scored against the seed
        masks (color / note / residual overlap), mean stain features, geometry
        (aspect, fill ratio, stroke width) and RGB uniformity, then either
        copied into the output mask or discarded.  Returns a uint8 0/255 mask.
        """
        if candidate is None or candidate.size == 0 or not np.any(candidate > 0):
            return np.zeros_like(candidate)

        tissue_bool = tissue_mask > 0
        # The band between the tissue boundary and its eroded core is used to
        # reject blobs that hug the tissue edge.
        tissue_core = self._mask_core(tissue_mask, self.cfg.stain_tissue_edge_exclusion)
        if tissue_core is None or not np.any(tissue_core > 0):
            edge_band_bool = tissue_bool
        else:
            edge_band_bool = cv2.bitwise_and(tissue_mask, cv2.bitwise_not(tissue_core)) > 0

        color_seed_bool = color_seed > 0
        note_seed_bool = note_seed > 0
        residual_seed_bool = residual_seed > 0

        num, labels, stats, _ = cv2.connectedComponentsWithStats(candidate, connectivity=8)
        out = np.zeros_like(candidate)
        for i in range(1, num):
            x, y, w, h2, area = [int(vv) for vv in stats[i]]
            if area <= 0:
                continue

            comp = labels == i
            # Per-component overlap fractions against each evidence mask.
            overlap_tissue = int(np.count_nonzero(comp & tissue_bool)) / float(area)
            edge_overlap = int(np.count_nonzero(comp & edge_band_bool)) / float(area)
            color_overlap = int(np.count_nonzero(comp & color_seed_bool)) / float(area)
            note_overlap = int(np.count_nonzero(comp & note_seed_bool)) / float(area)
            residual_overlap = int(np.count_nonzero(comp & residual_seed_bool)) / float(area)

            # Mean stain / texture / HSV statistics over the component.
            mean_he = float(he_sum[comp].mean())
            mean_residual = float(residual[comp].mean())
            mean_ratio = float(residual_ratio[comp].mean())
            mean_texture = float(texture[comp].mean())
            mean_sat = float(s[comp].mean())
            mean_val = float(v[comp].mean())

            # Mean per-channel RGB std: low values mean a uniform color patch.
            rgb_vals = thumbnail_rgb[comp].astype(np.float32)
            rgb_std = float(rgb_vals.std(axis=0).mean()) if rgb_vals.size > 0 else 0.0

            long_side = max(w, h2)
            short_side = max(1, min(w, h2))
            aspect = long_side / float(short_side)
            fill_ratio = area / float(max(1, w * h2))
            comp_crop = (labels[y : y + h2, x : x + w] == i).astype(np.uint8) * 255
            mean_width, width_cv = self._component_width_stats(comp_crop)
            # Fraction of the stain signal explained by H&E (high => real tissue).
            he_retention = mean_he / max(mean_he + mean_residual, 1e-6)
            smooth = mean_texture <= float(self.cfg.stain_keep_texture_max)
            uniform = rgb_std <= float(self.cfg.stain_uniform_rgb_std_max)
            elongated = aspect >= float(self.cfg.stain_min_aspect_or_smooth)
            # "Pen-like": thin elongated stroke with stable width, or a smooth
            # uniform sparse blob that is still narrow enough.
            pen_like_geometry = (
                elongated
                and mean_width <= float(self.cfg.stain_pen_width_max)
                and width_cv <= float(self.cfg.stain_pen_width_cv_max)
            ) or (
                smooth
                and uniform
                and fill_ratio <= float(self.cfg.stain_pen_fill_ratio_max)
                and mean_width <= float(self.cfg.stain_pen_width_max) * 1.35
            )
            strong_seed = color_overlap >= float(self.cfg.stain_seed_keep_overlap)
            strong_note = note_overlap >= float(self.cfg.stain_note_keep_overlap)
            residual_only = not strong_seed and not strong_note
            strong_residual = (
                mean_residual >= max(float(self.cfg.stain_residual_min), residual_thr * 0.8)
                and mean_ratio >= max(float(self.cfg.stain_ratio_min), ratio_thr * 0.75)
                and mean_sat >= float(self.cfg.stain_candidate_sat_min)
                and mean_texture < float(self.cfg.stain_keep_texture_max) * 1.05
            )
            residual_dominant = (
                mean_ratio >= max(float(self.cfg.stain_ratio_min), ratio_thr * 0.9)
                and he_retention <= float(self.cfg.stain_he_retention_max)
            )

            # --- rejection heuristics -------------------------------------
            # Textured, H&E-rich, compact blobs inside tissue: real tissue.
            tissue_like = (
                overlap_tissue >= 0.55
                and mean_he > max(mean_residual * float(self.cfg.stain_he_residual_reject_ratio), 0.12)
                and mean_texture >= max(0.035, float(self.cfg.stain_reject_texture_min) * 0.75)
                and not elongated
                and color_overlap < 0.20
                and note_overlap < 0.15
            )
            # Textured blobs sitting on the tissue edge band.
            edge_blob = (
                edge_overlap >= float(self.cfg.stain_edge_reject_overlap)
                and aspect < 3.0
                and mean_texture >= max(0.035, float(self.cfg.stain_reject_texture_min) * 0.75)
                and note_overlap < 0.12
                and mean_residual < residual_thr * 1.35
            )
            # Tiny components with no supporting seed evidence.
            tiny_unstable = (
                area < int(self.cfg.global_stain_min_area)
                and not strong_seed
                and not strong_note
                and residual_overlap < float(self.cfg.stain_residual_keep_overlap)
                and aspect < 2.0
            )
            # Small residual-only blobs whose H&E retention says "tissue".
            residual_blob = (
                residual_only
                and area < max(
                    int(self.cfg.stain_residual_blob_min_area),
                    int(self.cfg.global_stain_min_area) * 2,
                )
                and not pen_like_geometry
                and he_retention > float(self.cfg.stain_he_retention_max) * 0.92
            )
            # Large compact textured blobs dominated by tissue overlap.
            broad_tissue_blob = (
                overlap_tissue >= 0.65
                and fill_ratio >= 0.48
                and not elongated
                and he_retention > float(self.cfg.stain_he_retention_max)
                and mean_texture >= max(0.03, float(self.cfg.stain_reject_texture_min) * 0.7)
                and color_overlap < 0.25
                and note_overlap < 0.15
            )
            if tiny_unstable or tissue_like or edge_blob or residual_blob or broad_tissue_blob:
                continue

            # --- keep heuristics ------------------------------------------
            keep = False
            if strong_seed and (smooth or uniform or pen_like_geometry):
                keep = True
            elif strong_note and (pen_like_geometry or smooth or mean_val <= float(self.cfg.note_dark_val_max)):
                keep = True
            elif strong_residual and residual_dominant and (
                pen_like_geometry or residual_overlap >= float(self.cfg.stain_residual_keep_overlap)
            ):
                keep = True
            elif (
                residual_only
                and residual_overlap >= float(self.cfg.stain_residual_keep_overlap)
                and mean_sat >= float(self.cfg.stain_candidate_sat_min)
                and residual_dominant
                and pen_like_geometry
            ):
                keep = True

            if keep:
                out[comp] = 255

        return out

    def _detect_global_stain(
        self,
        thumbnail_rgb: np.ndarray,
        h: np.ndarray,
        s: np.ndarray,
        v: np.ndarray,
        note_raw: np.ndarray,
        tissue_mask: np.ndarray,
        note: np.ndarray,
        he_sum: np.ndarray,
        residual: np.ndarray,
        residual_ratio: np.ndarray,
        texture: np.ndarray,
    ) -> np.ndarray:
        """Detect global stain / ink marks around the tissue.

        Seeds blue/green/red ink and dark note regions, adds a residual-stain
        seed thresholded by percentiles, then filters the union component-wise
        via :meth:`_classify_stain_components`.  Returns a uint8 0/255 mask.
        """
        # note_raw is accepted for signature symmetry but unused here.
        _ = note_raw
        close_k = self._odd(max(3, int(self.cfg.note_close_kernel)))
        close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (close_k,
close_k))
        # Stains are only searched near tissue: dilate the tissue mask to get
        # the support region; bail out early if there is no support at all.
        support_mask = self._dilate_mask(tissue_mask, int(self.cfg.stain_support_dilate))
        if support_mask is None or support_mask.size == 0 or not np.any(support_mask > 0):
            return np.zeros_like(tissue_mask)

        # int16 channels so the channel-difference comparisons cannot wrap.
        r = thumbnail_rgb[:, :, 0].astype(np.int16)
        g = thumbnail_rgb[:, :, 1].astype(np.int16)
        b = thumbnail_rgb[:, :, 2].astype(np.int16)

        # Dark saturated elongated components anywhere in the image: ink strokes.
        note_seed = np.zeros_like(tissue_mask)
        ink_global = ((v < self.cfg.ink_val_max) & (s > self.cfg.ink_sat_min)).astype(np.uint8) * 255
        num_ink, labels_ink, stats_ink, _ = cv2.connectedComponentsWithStats(ink_global, connectivity=8)
        for i in range(1, num_ink):
            _, _, w, h2, area = [int(vv) for vv in stats_ink[i]]
            if area < int(self.cfg.ink_min_area):
                continue
            long_side = max(w, h2)
            short_side = max(1, min(w, h2))
            if long_side / float(short_side) < float(self.cfg.ink_min_aspect):
                continue
            note_seed[labels_ink == i] = 255

        # Blue ink: hue window plus blue channel dominating red and green.
        blue_seed = (
            (h >= self.cfg.ink_blue_h_min)
            & (h <= self.cfg.ink_blue_h_max)
            & (s >= self.cfg.ink_blue_s_min)
            & (v <= self.cfg.ink_blue_v_max)
            & (b > r + self.cfg.ink_blue_b_over_r)
            & (b > g + self.cfg.ink_blue_b_over_g)
        ).astype(np.uint8) * 255
        if np.any(blue_seed > 0):
            blue_seed = cv2.bitwise_and(blue_seed, support_mask)
            # Two-stage growth: a soft expansion, then a looser grow pass.
            blue_soft = self._expand_color_mark(
                blue_seed,
                (h >= self.cfg.ink_blue_expand_h_min) & (h <= self.cfg.ink_blue_expand_h_max),
                s >= self.cfg.ink_blue_expand_s_min,
                (b > r + self.cfg.ink_blue_expand_b_over_r) & (b > g + self.cfg.ink_blue_expand_b_over_g),
                self.cfg.ink_blue_expand_dilate,
            )
            blue_seed = self._expand_color_mark(
                blue_soft,
                (h >= self.cfg.ink_blue_grow_h_min) & (h <= self.cfg.ink_blue_grow_h_max),
                s >= self.cfg.ink_blue_grow_s_min,
                (b >= r + self.cfg.ink_blue_grow_b_over_r) & (b >= g + self.cfg.ink_blue_grow_b_over_g),
                self.cfg.ink_blue_grow_dilate,
                v <= self.cfg.ink_blue_grow_v_max,
            )
        else:
            blue_seed = np.zeros_like(tissue_mask)

        # Green ink: analogous hue window and green-channel dominance.
        green_seed = (
            (h >= self.cfg.global_stain_green_h_min)
            & (h <= self.cfg.global_stain_green_h_max)
            & (s >= self.cfg.global_stain_green_s_min)
            & (v <= self.cfg.global_stain_green_v_max)
            & (g >= r + self.cfg.global_stain_green_g_over_r)
            & (g >= b + self.cfg.global_stain_green_g_over_b)
        ).astype(np.uint8) * 255
        if np.any(green_seed > 0):
            green_seed = cv2.bitwise_and(green_seed, support_mask)
            green_seed = cv2.morphologyEx(green_seed, cv2.MORPH_CLOSE, close_kernel, iterations=1)
            green_seed = self._expand_color_mark(
                green_seed,
                (h >= self.cfg.global_stain_green_expand_h_min) & (h <= self.cfg.global_stain_green_expand_h_max),
                s >= self.cfg.global_stain_green_expand_s_min,
                (g >= r + self.cfg.global_stain_green_expand_g_over_r)
                & (g >= b + self.cfg.global_stain_green_expand_g_over_b),
                self.cfg.global_stain_green_expand_dilate,
            )
            green_seed = self._expand_color_mark(
                green_seed,
                (h >= self.cfg.global_stain_green_grow_h_min) & (h <= self.cfg.global_stain_green_grow_h_max),
                s >= self.cfg.global_stain_green_grow_s_min,
                (g >= r + self.cfg.global_stain_green_grow_g_over_r)
                & (g >= b + self.cfg.global_stain_green_grow_g_over_b),
                self.cfg.global_stain_green_grow_dilate,
            )
            green_close_k = self._odd(max(3, int(self.cfg.global_stain_green_close_kernel)))
            green_close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (green_close_k, green_close_k))
            green_seed = cv2.morphologyEx(green_seed, cv2.MORPH_CLOSE, green_close_kernel, iterations=1)
            green_seed = self._filter_small_components(green_seed, self.cfg.global_stain_green_min_area)
        else:
            green_seed = np.zeros_like(tissue_mask)

        # Red ink: hue wraps around 0, hence the (<=10 | >=160) window.
        red_seed = (
            ((h <= 10) | (h >= 160))
            & (s >= self.cfg.global_stain_red_s_min)
            & (v <= self.cfg.global_stain_red_v_max)
            & (r >= g + self.cfg.global_stain_red_r_over_g)
            & (r >= b + self.cfg.global_stain_red_r_over_b)
        ).astype(np.uint8) * 255
        if np.any(red_seed > 0):
            red_seed = cv2.bitwise_and(red_seed, support_mask)
            red_seed = cv2.morphologyEx(red_seed, cv2.MORPH_CLOSE, close_kernel, iterations=1)
            red_seed = self._expand_color_mark(
                red_seed,
                ((h <= 12) | (h >= 158)),
                s >= self.cfg.global_stain_red_expand_s_min,
                (r >= g + self.cfg.global_stain_red_expand_r_over_g) & (r >= b + self.cfg.global_stain_red_expand_r_over_b),
                self.cfg.global_stain_red_expand_dilate,
            )
            # Drop small compact red blobs (likely RBC / tissue) before the
            # generic small-component filter.
            red_seed = self._filter_small_compact_components(
                red_seed,
                area_max=self.cfg.global_stain_red_compact_area_max,
                keep_min_aspect=self.cfg.global_stain_red_compact_keep_min_aspect,
            )
            red_seed = self._filter_small_components(red_seed, self.cfg.global_stain_red_min_area)
        else:
            red_seed = np.zeros_like(tissue_mask)

        color_seed = cv2.bitwise_or(blue_seed, green_seed)
        color_seed = cv2.bitwise_or(color_seed, red_seed)
        color_seed = cv2.bitwise_and(color_seed, support_mask)

        # Dark, desaturated pixels inside tissue next to a detected note are
        # treated as part of the note (e.g. ink bleeding into tissue).
        dark_inside = np.zeros_like(tissue_mask)
        dark_open_k = self._odd(max(3, int(self.cfg.global_stain_dark_open_kernel)))
        dark_open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dark_open_k, dark_open_k))
        if np.any(note > 0):
            dark_expand_k = self._odd(max(3, int(self.cfg.global_stain_dark_expand_dilate)))
            dark_expand_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dark_expand_k, dark_expand_k))
            note_neighbor = cv2.dilate(note, dark_expand_kernel, iterations=1)
            dark_inside = (
                (v <= self.cfg.global_stain_dark_v_max)
                & (s <= self.cfg.global_stain_dark_s_max)
                & (tissue_mask > 0)
                & (note_neighbor > 0)
            ).astype(np.uint8) * 255
            dark_inside = cv2.morphologyEx(dark_inside, cv2.MORPH_CLOSE, dark_open_kernel, iterations=1)
            dark_inside = self._filter_small_components(dark_inside, self.cfg.global_stain_dark_min_area)

        note_seed = cv2.bitwise_or(note_seed, note)
        note_seed = cv2.bitwise_or(note_seed, dark_inside)
        note_seed = cv2.bitwise_and(note_seed, support_mask)

        # Residual-stain thresholds are computed over an anchor region: near
        # existing seeds when available, otherwise the (tightened) tissue core.
        seed_support = cv2.bitwise_or(color_seed, note_seed)
        has_anchor_seed = np.any(seed_support > 0)
        if has_anchor_seed:
            seed_support = self._dilate_mask(seed_support, max(int(self.cfg.stain_support_dilate), 31))
            residual_support_bool = seed_support > 0
        else:
            tissue_core = self._mask_core(tissue_mask, max(1, int(self.cfg.stain_tissue_edge_exclusion)))
            if tissue_core is not None and np.any(tissue_core > 0):
                residual_support_bool = tissue_core > 0
            else:
                residual_support_bool = tissue_mask > 0

        # Without anchor seeds the thresholds are scaled up to be stricter.
        residual_scale = float(self.cfg.stain_residual_no_seed_scale) if not has_anchor_seed else 1.0
        ratio_scale = float(self.cfg.stain_ratio_no_seed_scale) if not has_anchor_seed else 1.0
        residual_thr = max(
            float(self.cfg.stain_residual_min),
            self._safe_percentile(
                residual[residual_support_bool],
                float(self.cfg.stain_residual_percentile),
                float(self.cfg.stain_residual_min),
            ),
        ) * residual_scale
        ratio_thr = max(
            float(self.cfg.stain_ratio_min),
            self._safe_percentile(
                residual_ratio[residual_support_bool],
                float(self.cfg.stain_ratio_percentile),
                float(self.cfg.stain_ratio_min),
            ),
        ) * ratio_scale
        sat_min = int(self.cfg.stain_candidate_sat_min)
        if not has_anchor_seed:
            sat_min = min(255, sat_min + int(self.cfg.stain_candidate_sat_no_seed_boost))
        texture_limit = float(self.cfg.stain_keep_texture_max) * (0.95 if has_anchor_seed else 0.82)
        residual_seed = (
            residual_support_bool
            & (residual >= residual_thr)
            & (residual_ratio >= ratio_thr)
            & ((s >= sat_min) | (v <= self.cfg.stain_candidate_dark_v_max))
            & (texture <= texture_limit)
        ).astype(np.uint8) * 255
        if np.any(residual_seed > 0):
            residual_seed = cv2.morphologyEx(residual_seed, cv2.MORPH_CLOSE, close_kernel, iterations=1)

        # Union of all evidence, clamped to the support region, then filtered
        # component-by-component.
        global_stain = cv2.bitwise_or(color_seed, note_seed)
        global_stain = cv2.bitwise_or(global_stain, residual_seed)
        global_stain = cv2.bitwise_and(global_stain, support_mask)
        global_stain = cv2.morphologyEx(global_stain, cv2.MORPH_CLOSE, close_kernel, iterations=1)

        global_stain = self._classify_stain_components(
            global_stain,
            thumbnail_rgb=thumbnail_rgb,
            tissue_mask=tissue_mask,
            color_seed=color_seed,
            note_seed=note_seed,
            residual_seed=residual_seed,
            s=s,
            v=v,
            he_sum=he_sum,
            residual=residual,
            residual_ratio=residual_ratio,
            texture=texture,
            residual_thr=residual_thr,
            ratio_thr=ratio_thr,
        )
        global_stain = cv2.morphologyEx(global_stain, cv2.MORPH_CLOSE, close_kernel, iterations=1)
        global_stain = self._filter_small_compact_components(
            global_stain,
            area_max=self.cfg.global_stain_compact_area_max,
            keep_min_aspect=self.cfg.global_stain_compact_keep_min_aspect,
        )
        global_stain = self._filter_tissue_like_purple_components(
            global_stain,
            tissue_mask=tissue_mask,
            h=h,
            s=s,
            v=v,
        )
        return self._filter_small_components(global_stain, self.cfg.global_stain_min_area)

    def detect(self, thumbnail_rgb: np.ndarray) -> DetectionResult:
        """Run the full detection pipeline on an HxWx3 RGB thumbnail.

        Produces tissue / note / artifact / bubble / global-stain masks plus
        their contours, bundled in a :class:`DetectionResult`.  Raises
        ``ValueError`` on empty or non-RGB input.
        """
        if thumbnail_rgb is None or thumbnail_rgb.size == 0:
            raise ValueError("thumbnail_rgb 为空")
        if thumbnail_rgb.ndim != 3 or thumbnail_rgb.shape[2] != 3:
            raise ValueError(f"thumbnail_rgb 必须为 HxWx3 RGB,当前: {thumbnail_rgb.shape}")

        hsv = self._to_hsv(thumbnail_rgb)
        h, s, v = cv2.split(hsv)
        stain_features = self._compute_stain_features(thumbnail_rgb)

        # --- note mask (pen marks / very dark artifacts) ---
        # Detected over the whole image first; later restricted to the region
        # inside the tissue contours.
        note_raw = (
            (v < self.cfg.note_val_max)
            & ((s < self.cfg.note_sat_max) | (v < self.cfg.note_val_strict))
        ).astype(np.uint8) * 255

        # --- tissue ---
        tissue = self._detect_tissue(
            h=h,
            s=s,
            v=v,
            note_raw=note_raw,
            he_sum=stain_features["he_sum"],
            od_sum=stain_features["od_sum"],
        )

        # Compute the tissue contours first, then run pen-mark detection only
        # inside them; pen marks outside the contours are ignored.
        tissue_contours = self._find_contours(tissue)
        mask_inside_tissue = np.zeros_like(tissue)
        if tissue_contours:
            cv2.drawContours(mask_inside_tissue, tissue_contours, -1, 255, thickness=cv2.FILLED)

        # --- note ---
        note = self._detect_note(
            thumbnail_rgb=thumbnail_rgb,
            h=h,
            s=s,
            v=v,
            tissue=tissue,
            mask_inside_tissue=mask_inside_tissue,
            note_raw=note_raw,
        )

        # --- artifact ---
        artifact =
self._detect_artifact(
            thumbnail_rgb=thumbnail_rgb,
            h=h,
            s=s,
            v=v,
            mask_inside_tissue=mask_inside_tissue,
            note=note,
        )

        global_stain = self._detect_global_stain(
            thumbnail_rgb=thumbnail_rgb,
            h=h,
            s=s,
            v=v,
            note_raw=note_raw,
            tissue_mask=tissue,
            note=note,
            he_sum=stain_features["he_sum"],
            residual=stain_features["residual"],
            residual_ratio=stain_features["residual_ratio"],
            texture=stain_features["texture"],
        )

        # --- bubble mask (optional; disabled by default to avoid false kills) ---
        bubble = np.zeros_like(tissue)
        if self.cfg.enable_bubble:
            # very lightweight heuristic: bright & low saturation small blobs
            bubble = ((s < 10) & (v > 230)).astype(np.uint8) * 255
            open_k = self._odd(max(3, int(self.cfg.tissue_open_kernel)))
            open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (open_k, open_k))
            bubble = cv2.morphologyEx(bubble, cv2.MORPH_OPEN, open_kernel, iterations=1)
            bubble = self._filter_small_components(bubble, self.cfg.bubble_min_area)

        contours = {
            "tissue": tissue_contours,
            "note": self._find_contours(note),
            "artifact": self._find_contours(artifact),
            "bubble": self._find_contours(bubble) if self.cfg.enable_bubble else [],
            "global_stain": self._find_contours(global_stain),
        }

        return DetectionResult(
            tissue_mask=tissue,
            bubble_mask=bubble,
            note_mask=note,
            artifact_mask=artifact,
            global_stain_mask=global_stain,
            contours=contours,
        )

    def build_detection_from_external_masks(
        self,
        thumbnail_rgb: np.ndarray,
        tissue_mask: np.ndarray,
        note_mask: np.ndarray | None = None,
        global_stain_mask: np.ndarray | None = None,
    ) -> DetectionResult:
        """Assemble a :class:`DetectionResult` from externally supplied masks.

        Instead of running tissue/note/stain detection, this normalises the
        provided masks to uint8 0/255, validates that their spatial shapes
        match the thumbnail, restricts the note mask to the tissue contours,
        and only computes the artifact (and optional bubble) masks locally.
        When *global_stain_mask* is omitted it defaults to the note mask.
        Raises ``ValueError`` on empty/non-RGB input or shape mismatches.
        """
        if thumbnail_rgb is None or thumbnail_rgb.size == 0:
            raise ValueError("thumbnail_rgb 为空")
        if thumbnail_rgb.ndim != 3 or thumbnail_rgb.shape[2] != 3:
            raise ValueError(f"thumbnail_rgb 必须为 HxWx3 RGB,当前: {thumbnail_rgb.shape}")

        # Normalise any truthy mask to a strict 0/255 uint8 mask.
        tissue = ((np.asarray(tissue_mask) > 0).astype(np.uint8) * 255)
        if tissue.shape[:2] != thumbnail_rgb.shape[:2]:
            raise ValueError(
                "tissue_mask shape must match thumbnail spatial shape, "
                f"got {tissue.shape[:2]} vs {thumbnail_rgb.shape[:2]}"
            )

        note_full = np.zeros_like(tissue)
        if note_mask is not None:
            note_full = ((np.asarray(note_mask) > 0).astype(np.uint8) * 255)
            if note_full.shape[:2] != thumbnail_rgb.shape[:2]:
                raise ValueError(
                    "note_mask shape must match thumbnail spatial shape, "
                    f"got {note_full.shape[:2]} vs {thumbnail_rgb.shape[:2]}"
                )

        # Default the stain mask to the note mask when not supplied.
        global_stain = note_full.copy()
        if global_stain_mask is not None:
            global_stain = ((np.asarray(global_stain_mask) > 0).astype(np.uint8) * 255)
            if global_stain.shape[:2] != thumbnail_rgb.shape[:2]:
                raise ValueError(
                    "global_stain_mask shape must match thumbnail spatial shape, "
                    f"got {global_stain.shape[:2]} vs {thumbnail_rgb.shape[:2]}"
                )

        hsv = self._to_hsv(thumbnail_rgb)
        h, s, v = cv2.split(hsv)

        tissue_contours = self._find_contours(tissue)
        mask_inside_tissue = np.zeros_like(tissue)
        if tissue_contours:
            cv2.drawContours(mask_inside_tissue, tissue_contours, -1, 255, thickness=cv2.FILLED)

        # Notes outside the tissue contours are dropped, mirroring detect().
        note_inside_tissue = cv2.bitwise_and(note_full, mask_inside_tissue)
        artifact = self._detect_artifact(
            thumbnail_rgb=thumbnail_rgb,
            h=h,
            s=s,
            v=v,
            mask_inside_tissue=mask_inside_tissue,
            note=note_inside_tissue,
        )

        # Optional bubble mask: same lightweight heuristic as detect().
        bubble = np.zeros_like(tissue)
        if self.cfg.enable_bubble:
            bubble = ((s < 10) & (v > 230)).astype(np.uint8) * 255
            open_k = self._odd(max(3, int(self.cfg.tissue_open_kernel)))
            open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (open_k, open_k))
            bubble = cv2.morphologyEx(bubble, cv2.MORPH_OPEN, open_kernel, iterations=1)
            bubble = self._filter_small_components(bubble, self.cfg.bubble_min_area)

        contours = {
            "tissue": tissue_contours,
            "note": self._find_contours(note_inside_tissue),
            "artifact": self._find_contours(artifact),
            "bubble": self._find_contours(bubble) if self.cfg.enable_bubble else [],
            "global_stain":
self._find_contours(global_stain),
        }

        return DetectionResult(
            tissue_mask=tissue,
            bubble_mask=bubble,
            note_mask=note_inside_tissue,
            artifact_mask=artifact,
            global_stain_mask=global_stain,
            contours=contours,
        )

    def _find_contours(self, mask: np.ndarray) -> List[np.ndarray]:
        """Return the external contours of *mask* (empty list for None)."""
        if mask is None:
            return []
        # CHAIN_APPROX_NONE keeps every boundary point so the contours hug the
        # mask tightly; the cost is more coordinate points and larger JSON.
        cnts, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        return cnts

    def _filter_small_components(self, mask: np.ndarray, min_area: int) -> np.ndarray:
        """Return a copy of *mask* with connected components below *min_area* removed."""
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
        out = np.zeros_like(mask)
        for i in range(1, num_labels):
            area = stats[i, cv2.CC_STAT_AREA]
            if area >= min_area:
                out[labels == i] = 255
        return out

diff --git a/runtime/ops/mapper/wsi_enhance_operator/wsi_reader/__init__.py b/runtime/ops/mapper/wsi_enhance_operator/wsi_reader/__init__.py
new file mode 100644
index 000000000..ea7fe3ad7
--- /dev/null
+++ b/runtime/ops/mapper/wsi_enhance_operator/wsi_reader/__init__.py
@@ -0,0 +1,5 @@
from .wsi_reader import WSIReader

__all__ = ["WSIReader"]

# WSI reader package initialisation.
\ No newline at end of file
diff --git a/runtime/ops/mapper/wsi_enhance_operator/wsi_reader/wsi_reader.py b/runtime/ops/mapper/wsi_enhance_operator/wsi_reader/wsi_reader.py
new file mode 100644
index 000000000..f7e4f77ef
--- /dev/null
+++ b/runtime/ops/mapper/wsi_enhance_operator/wsi_reader/wsi_reader.py
@@ -0,0 +1,154 @@
from __future__ import annotations

import ctypes
import os
from typing import Tuple

import numpy as np

# Candidate paths for a bundled OpenSlide shared library; tried in order so a
# runtime image without a system-wide OpenSlide can still load it.
OPENSLIDE_LIBRARY_CANDIDATES = [
    "/models/WSIEnhance/openslide/libopenslide.so.1",
    "/models/WSIEnhance/openslide/libopenslide.so",
]


def _preload_openslide() -> Exception | None:
    """Preload the bundled libopenslide with RTLD_GLOBAL.

    Returns None on success (or when no candidate path exists); returns the
    last load error otherwise, so the import guard below can report it.
    """
    for candidate in OPENSLIDE_LIBRARY_CANDIDATES:
        if not os.path.exists(candidate):
            continue
        try:
            # RTLD_GLOBAL so the subsequent `import openslide` can resolve the
            # symbols of the preloaded library.
            ctypes.CDLL(candidate, mode=ctypes.RTLD_GLOBAL)
            return None
except Exception as exc:  # pragma: no cover
            last_error = exc
    # NOTE(review): `locals().get("last_error")` relies on whether the loop
    # body ever assigned `last_error`; initialising it to None before the loop
    # and returning it directly would be clearer — confirm before changing.
    return locals().get("last_error")


# Import guards: the module stays importable without OpenSlide/Pillow; the
# errors are kept so WSIReader can surface them in a helpful message later.
try:
    _OPENSLIDE_PRELOAD_ERROR = _preload_openslide()
    import openslide
except Exception as exc:  # pragma: no cover
    openslide = None
    _OPENSLIDE_IMPORT_ERROR = _OPENSLIDE_PRELOAD_ERROR or exc
else:
    _OPENSLIDE_IMPORT_ERROR = None

try:
    from PIL import Image
except Exception as exc:  # pragma: no cover
    Image = None
    _PIL_IMPORT_ERROR = exc
else:
    _PIL_IMPORT_ERROR = None


# Plain raster formats accepted by the Pillow fallback path.
RASTER_EXTENSIONS = {
    ".png",
    ".jpg",
    ".jpeg",
    ".bmp",
    ".tif",
    ".tiff",
    ".webp",
}


class WSIReader:
    """Reader abstraction over a whole-slide image.

    Uses OpenSlide when available; otherwise falls back to loading ordinary
    raster images fully into memory via Pillow ("raster" mode).  Supports use
    as a context manager.
    """

    def __init__(self, file_path: str):
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Input image not found: {file_path}")

        self.file_path = file_path
        self._slide = None       # openslide.OpenSlide handle in "openslide" mode
        self._image = None       # full RGB ndarray in "raster" mode
        self._mode = "openslide"

        suffix = os.path.splitext(file_path)[1].lower()
        # OpenSlide is preferred whenever importable, regardless of suffix.
        if openslide is not None:
            self._slide = openslide.OpenSlide(file_path)
            return

        if suffix in RASTER_EXTENSIONS and Image is not None:
            self._mode = "raster"
            self._image = np.array(Image.open(file_path).convert("RGB"), dtype=np.uint8)
            return

        raise ImportError(self._build_import_error_message(file_path))

    @staticmethod
    def _build_import_error_message(file_path: str) -> str:
        """Compose a single-line diagnostic including any captured import errors."""
        parts = [
            f"Unable to open WSI file: {file_path}",
            "OpenSlide shared library is not available in the runtime.",
        ]
        if _OPENSLIDE_IMPORT_ERROR is not None:
            parts.append(f"OpenSlide import error: {_OPENSLIDE_IMPORT_ERROR}")
        if Image is None and _PIL_IMPORT_ERROR is not None:
            parts.append(f"Pillow import error: {_PIL_IMPORT_ERROR}")
        parts.append(
            "Install the OpenSlide shared library for true WSI formats "
            "such as .svs/.ndpi, or use a standard raster image for fallback testing."
        )
        return " ".join(parts)

    @property
    def dimensions(self) -> Tuple[int, int]:
        """(width, height) of level 0."""
        if self._mode == "raster":
            height, width = self._image.shape[:2]
            return (int(width), int(height))
        return self._slide.dimensions

    @property
    def width(self) -> int:
        return int(self.dimensions[0])

    @property
    def height(self) -> int:
        return int(self.dimensions[1])

    @property
    def level_count(self) -> int:
        # Raster mode has no pyramid, hence a single level.
        return 1 if self._mode == "raster" else int(self._slide.level_count)

    def get_thumbnail(self, max_size: Tuple[int, int] = (2048, 2048)) -> np.ndarray:
        """Return an RGB uint8 thumbnail no larger than *max_size*."""
        if self._mode == "raster":
            image = Image.fromarray(self._image)
            image.thumbnail(max_size)
            return np.array(image, dtype=np.uint8)

        thumb = self._slide.get_thumbnail(max_size)
        arr = np.array(thumb)
        # Normalise grayscale to 3-channel and drop any alpha channel.
        if arr.ndim == 2:
            arr = np.stack([arr, arr, arr], axis=-1)
        if arr.shape[-1] == 4:
            arr = arr[:, :, :3]
        return arr.astype(np.uint8, copy=False)

    def read_region(self, x: int, y: int, width: int, height: int, level: int = 0) -> np.ndarray:
        """Read a (height, width, 3) RGB uint8 region at the given level.

        In raster mode *level* is ignored and out-of-bounds areas are padded
        with white (255), matching the blank-background convention.
        """
        if self._mode == "raster":
            x0 = max(0, int(x))
            y0 = max(0, int(y))
            x1 = min(self.width, x0 + int(width))
            y1 = min(self.height, y0 + int(height))
            out = np.full((int(height), int(width), 3), 255, dtype=np.uint8)
            crop = self._image[y0:y1, x0:x1]
            if crop.size:
                out[0 : crop.shape[0], 0 : crop.shape[1]] = crop
            return out

        region = self._slide.read_region((int(x), int(y)), int(level), (int(width), int(height)))
        arr = np.array(region)
        # OpenSlide returns RGBA; strip the alpha channel.
        if arr.shape[-1] == 4:
            arr = arr[:, :, :3]
        return arr.astype(np.uint8, copy=False)

    def close(self) -> None:
        """Release the OpenSlide handle (idempotent; no-op in raster mode)."""
        if self._slide is not None:
            self._slide.close()
            self._slide = None

    def __enter__(self) -> "WSIReader":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()
diff --git a/runtime/ops/mapper/wsi_enhance_operator/wsi_reader/wsi_types.py b/runtime/ops/mapper/wsi_enhance_operator/wsi_reader/wsi_types.py
new file mode 100644
index 000000000..46c903a2e
---
/dev/null
+++ b/runtime/ops/mapper/wsi_enhance_operator/wsi_reader/wsi_types.py
@@ -0,0 +1,60 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np


class WSIFormat:
    """Namespace of supported WSI backend identifiers."""

    OPENS = "openslide"  # OpenSlide-backed reading


@dataclass
class WSIReaderConfig:
    """Tunable settings for a WSI reader."""

    format: str = WSIFormat.OPENS      # backend identifier
    cache_size: int = 1024             # cache budget (units defined by the consumer)
    enable_lazy_loading: bool = True   # defer reading until data is requested
    max_workers: int = 4               # parallel worker cap


@dataclass
class SlideInfo:
    """Static metadata describing one slide."""

    name: str
    path: str
    width: int                               # level-0 width in pixels
    height: int                              # level-0 height in pixels
    mpp_x: float = 0.0                       # microns per pixel, x axis
    mpp_y: float = 0.0                       # microns per pixel, y axis
    levels: int = 1                          # pyramid level count
    format: str = WSIFormat.OPENS
    vendor: Optional[str] = None             # scanner vendor, when known
    magnification: Optional[float] = None    # objective magnification, when known
    tile_count: Optional[int] = None
    file_size_bytes: Optional[int] = None

    def get_dimensions(self) -> Tuple[int, int]:
        """Return (width, height) of level 0."""
        return self.width, self.height


@dataclass
class PatchInfo:
    """One extracted patch and its placement within the slide."""

    x: int                    # top-left x in source coordinates
    y: int                    # top-left y in source coordinates
    width: int
    height: int
    level: int                # pyramid level the patch was read from
    data: np.ndarray          # patch pixels
    source: str = "level_0"   # provenance label

    def get_position(self) -> Tuple[int, int]:
        """Return the (x, y) top-left corner."""
        return self.x, self.y

    def get_size(self) -> Tuple[int, int]:
        """Return the (width, height) of the patch."""
        return self.width, self.height


@dataclass(frozen=True)
class Coordinate:
    """Immutable, hashable 2-D integer coordinate."""

    x: int
    y: int