diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index f2ddb7f84..a5ba00e61 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -28,7 +28,10 @@ DaskDelayedJSONStore, MultiTaskSegmentor, _clear_zarr, + _crop_halo_post_process_output, + _get_postproc_tile_read_bounds, _get_sel_indices_margin_lines, + _normalise_postproc_halo, _post_save_json_store, _process_instance_predictions, _save_multitask_vertical_to_cache, @@ -889,9 +892,10 @@ class FakeVM: ) # --- Call function --- - new_zarr, new_da = _save_multitask_vertical_to_cache( + new_zarr, new_da, zarr_group = _save_multitask_vertical_to_cache( probabilities_zarr=probabilities_zarr, probabilities_da=probabilities_da, + zarr_group=None, probabilities=probabilities, idx=idx, tqdm_loop=tqdm_loop, @@ -905,11 +909,44 @@ class FakeVM: # new_zarr must be a real zarr array assert isinstance(new_zarr[idx], zarr.Array) + assert zarr_group is not None # Data was written correctly assert np.array_equal(new_zarr[idx][:], np.array([[1, 2, 3]])) +def test_multitask_vertical_merge_continues_after_zarr_spill( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Test multitask vertical merge appends all chunks after spilling to Zarr.""" + + class FakeVM: + """Fake psutil.virtual_memory() with extremely low available memory.""" + + available = 1 + + monkeypatch.setattr(psutil, "virtual_memory", FakeVM) + + values = np.arange(8 * 3, dtype=np.float32).reshape(8, 3, 1) + canvas = [da.from_array(values, chunks=(2, 3, 1))] + count = [da.from_array(np.ones_like(values), chunks=(2, 3, 1))] + output_locs_y = np.array([[0, 2], [2, 4], [4, 6], [6, 8]]) + + result = merge_multitask_vertical_chunkwise( + canvas=canvas, + count=count, + output_locs_y_=output_locs_y, + zarr_group=None, + save_path=tmp_path / "vertical.zarr", + memory_threshold=0, + output_shape=(8, 3), + verbose=False, + ) + + assert result[0].shape == values.shape + assert np.array_equal(result[0].compute(), values) + + def test_qupath_feature_class_dict_lookup_fails() -> None: """Test qupath_feature_class_dict lookup fails.""" qupath_json = DaskDelayedJSONStore.__new__(DaskDelayedJSONStore) @@ -1079,6 +1116,215 @@ def test_get_tile_info_small_image_triggers_early_return( assert np.all(flag == 0) +def test_postproc_halo_bounds_and_output_crop() -> None: + """Test halo-expanded tile output is cropped and shifted to core space.""" + halo_xy = _normalise_postproc_halo((3, 2)) + assert np.array_equal(halo_xy, np.array([2, 3])) + + read_bounds = _get_postproc_tile_read_bounds( + tile_bounds=(4, 5, 10, 11), + postproc_halo_xy=halo_xy, + image_shape=(12, 13), + ) + assert read_bounds == (2, 2, 12, 13) + + predictions = np.arange(11 * 10).reshape(11, 10) + info_dict = { + "box": np.array( + [ + [2, 3, 4, 5], + [5, 6, 7, 8], + [9, 6, 11, 8], + ], + dtype=np.int32, + ), + "centroid": np.array( + [ + [3, 4], + [6, 7], + [10, 7], + ], + dtype=np.float32, + ), + "contours": np.array( + [ + [[2, 3], [4, 3], [4, 5], [2, 5]], + [[5, 6], [7, 6], [7, 8], [5, 8]], + [[9, 6], [11, 6], [11, 8], [9, 8]], + ], + dtype=np.int32, + ), + "type": np.array([1, 2, 3], dtype=np.int32), + } + + cropped = _crop_halo_post_process_output( + post_process_output=( + { + "task_type": "gland", + "seg_type": "instance", + "predictions": predictions, + "info_dict": info_dict, + }, + ), + tile_bounds=(4, 5, 10, 11), + tile_read_bounds=read_bounds, + )[0] + + assert np.array_equal(cropped["predictions"], predictions[3:9, 2:8]) + assert np.array_equal(cropped["info_dict"]["type"], np.array([1, 2])) + assert np.array_equal( + cropped["info_dict"]["box"], + np.array([[0, 0, 2, 2], [3, 3, 5, 5]], dtype=np.int32), + ) + assert np.array_equal( + cropped["info_dict"]["centroid"], + np.array([[1, 1], [4, 4]], dtype=np.float32), + ) + assert np.array_equal( + cropped["info_dict"]["contours"][0], + np.array([[0, 0], [2, 0], [2, 2], [0, 2]], dtype=np.int32), + ) + + +def test_postproc_halo_ownership_without_centroids() -> None: + """Test halo ownership falls back to boxes and padded contours.""" + read_bounds = (2, 2, 12, 13) + predictions = np.arange(11 * 10).reshape(11, 10) + + box_cropped = _crop_halo_post_process_output( + post_process_output=( + { + "task_type": "gland", + "seg_type": "instance", + "predictions": predictions, + "info_dict": { + "box": np.array( + [ + [2, 3, 4, 5], + [9, 6, 11, 8], + ], + dtype=np.int32, + ), + "type": np.array([1, 2], dtype=np.int32), + }, + }, + ), + tile_bounds=(4, 5, 10, 11), + tile_read_bounds=read_bounds, + )[0] + assert np.array_equal( + box_cropped["info_dict"]["box"], + np.array([[0, 0, 2, 2]], dtype=np.int32), + ) + assert np.array_equal(box_cropped["info_dict"]["type"], np.array([1])) + + pad_value = np.iinfo(np.int32).min + contour_cropped = _crop_halo_post_process_output( + post_process_output=( + { + "task_type": "gland", + "seg_type": "instance", + "predictions": predictions, + "info_dict": { + "contours": np.array( + [ + [[2, 3], [4, 3], [4, 5], [2, 5]], + [[9, 6], [11, 6], [11, 8], [9, 8]], + [[pad_value, pad_value]] * 4, + ], + dtype=np.int32, + ), + "type": np.array([1, 2, 3], dtype=np.int32), + }, + }, + ), + tile_bounds=(4, 5, 10, 11), + tile_read_bounds=read_bounds, + )[0] + assert np.array_equal(contour_cropped["info_dict"]["type"], np.array([1])) + assert np.array_equal( + contour_cropped["info_dict"]["contours"][0], + np.array([[0, 0], [2, 0], [2, 2], [0, 2]], dtype=np.int32), + ) + + +def test_process_tile_mode_uses_postproc_halo( + track_tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test tile mode expands reads and crops outputs when halo is set.""" + seg = MultiTaskSegmentor.__new__(MultiTaskSegmentor) + seg.verbose = False + seg.num_workers = 1 + seg.mask_bounds = (0, 0, 10, 10) + seg.mask_padding = (0, 0, 0, 0) + seg.dataloader = SimpleNamespace( + dataset=SimpleNamespace( + reader=SimpleNamespace(slide_dimensions=lambda **_: (10, 10)), + ), + ) + seg._ioconfig = SimpleNamespace( + highest_input_resolution={}, + tile_shape=(4, 4), + to_baseline=lambda: SimpleNamespace(margin=0), + ) + + tile_info_sets = [ + [ + np.array([[2, 2, 6, 6]], dtype=np.int32), + np.array([[1, 1, 1, 1]], dtype=np.int32), + ], + [ + np.array([[6, 6, 10, 10]], dtype=np.int32), + np.array([[1, 1, 1, 1]], dtype=np.int32), + ], + ] + seg._get_tile_info = lambda **_: tile_info_sets + recorded_bounds = [] + expanded_predictions = np.arange(64, dtype=np.uint8).reshape(8, 8) + + def _compute_tile(tile_bounds: tuple[int, int, int, int]) -> tuple[dict]: + """Return one halo-expanded post-processing output.""" + recorded_bounds.append(tile_bounds) + return ( + { + "task_type": "instance", + "seg_type": "instance", + "predictions": expanded_predictions, + "info_dict": { + "box": np.empty((0, 4), dtype=np.int32), + "centroid": np.empty((0, 2), dtype=np.float32), + "contours": np.empty((0, 0, 2), dtype=np.int32), + "prob": np.empty((0,), dtype=np.float32), + "type": np.empty((0,), dtype=np.int32), + }, + }, + ) + + seg._compute_tile = _compute_tile + monkeypatch.setattr( + "tiatoolbox.models.engine.multi_task_segmentor.tqdm_dask_progress_bar", + lambda **kwargs: kwargs["write_tasks"], + ) + + output = seg._process_tile_mode( + probabilities=[da.zeros((10, 10, 1), chunks=(10, 10, 1))], + save_path=track_tmp_path / "halo.zarr", + memory_threshold=100, + return_predictions=(True,), + postproc_halo=2, + ) + + assert recorded_bounds == [(0, 0, 8, 8)] + assert len(output) == 1 + predictions = output[0]["predictions"] + assert np.array_equal(predictions[2:6, 2:6], expanded_predictions[2:6, 2:6]) + assert np.count_nonzero(predictions[:2, :]) == 0 + assert np.count_nonzero(predictions[:, :2]) == 0 + assert np.count_nonzero(predictions[6:, :]) == 0 + assert np.count_nonzero(predictions[:, 6:]) == 0 + + class FakeSeg(MultiTaskSegmentor): """Minimal subclass that allows us to override internals cleanly.""" @@ -1166,6 +1412,7 @@ def fake_store_probabilities( *_: Any, # noqa: ANN401 **__: Any, # noqa: ANN401 ) -> tuple[zarr.Array | None, da.Array | None]: + """Record unexpected probability-store calls during merge tests.""" nonlocal called_store called_store = True return None, None @@ -1612,6 +1859,7 @@ def test_post_save_json_store_deletes_empty_store( # ---- Proxy object that LOOKS like a zarr.Group ---- class GroupProxy: def __init__(self: GroupProxy, group: zarr.Group, path: Path | str) -> None: + """Wrap a Zarr group with a path used by cleanup code.""" self._group = group self.path = path self.store = group.store @@ -1619,19 +1867,23 @@ def __init__(self: GroupProxy, group: zarr.Group, path: Path | str) -> None: # Make isinstance(proxy, zarr.Group) return True @property def __class__(self: GroupProxy) -> type[zarr.Group]: + """Expose the wrapped object as a Zarr group for isinstance.""" return zarr.Group # Delegate attribute access def __getattr__( self: GroupProxy, item: str ) -> zarr.Group | zarr.Array | str | int | float | Iterable[str]: + """Delegate unknown attributes to the wrapped Zarr group.""" return getattr(self._group, item) # Delegate mapping behavior def keys(self: GroupProxy) -> Iterable[str]: + """Return keys from the wrapped Zarr group.""" return self._group.keys() def __getitem__(self: GroupProxy, item: str) -> zarr.Group | zarr.Array: + """Return an item from the wrapped Zarr group.""" return self._group[item] processed_predictions = GroupProxy(root, "dummy") @@ -1640,6 +1892,7 @@ def __getitem__(self: GroupProxy, item: str) -> zarr.Group | zarr.Array: called = {"flag": False} def fake_rmtree(path: Path | str, *, ignore_errors: bool) -> None: # noqa: ARG001 + """Record that cleanup attempted to remove an empty Zarr store.""" called["flag"] = True monkeypatch.setattr(shutil, "rmtree", fake_rmtree) @@ -1723,6 +1976,7 @@ def fake_save_qupath_json( save_path: Path | None, # noqa: ARG001 qupath_json: dict[str, Any], ) -> dict[str, Any]: + """Return generated QuPath JSON instead of writing it to disk.""" return qupath_json monkeypatch.setattr( @@ -1759,6 +2013,7 @@ def _build_single_qupath_feature( scale_factor: tuple[float, float], class_colors: dict[int, Any], ) -> dict[str, Any]: + """Delegate feature construction to the production JSON store.""" return DaskDelayedJSONStore._build_single_qupath_feature( self, i, class_dict, origin, scale_factor, class_colors ) diff --git a/tests/models/test_arch_cerberus.py b/tests/models/test_arch_cerberus.py new file mode 100644 index 000000000..f3546d35d --- /dev/null +++ b/tests/models/test_arch_cerberus.py @@ -0,0 +1,312 @@ +"""Unit tests for the Cerberus architecture.""" + +from __future__ import annotations + +import dask.array as da +import numpy as np +import pytest +import torch + +from tiatoolbox.models import Cerberus +from tiatoolbox.models.architecture import get_pretrained_model +from tiatoolbox.models.architecture.cerberus.model import ( + _build_tissue_raw_map, + _crop_center_tensor, + _inst_dict_for_dask_processing, + _pad_contours, +) +from tiatoolbox.models.architecture.cerberus.postproc import ( + PostProcInstErodedContourMap, + get_bounding_box, +) +from tiatoolbox.models.engine.io_config import IOInstanceSegmentorConfig + +PATCH_OUTPUT_SHAPE = (144, 144) +INFER_INPUT_SHAPE = (256, 256) + + +def _module_prefixed_state_dict(model: Cerberus) -> dict[str, torch.Tensor]: + """Return a Cerberus checkpoint state dict saved from DataParallel.""" + return {f"module.{key}": value for key, value in model.state_dict().items()} + + +def test_cerberus_load_weights_from_desc_checkpoint( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test Cerberus checkpoint loading with ``desc`` and ``module.`` prefixes.""" + source_model = Cerberus() + checkpoint = {"desc": _module_prefixed_state_dict(source_model)} + + def _mock_torch_load( + *_args: object, + **_kwargs: object, + ) -> dict[str, dict[str, torch.Tensor]]: + """Return a synthetic Cerberus checkpoint for load-weight tests.""" + return checkpoint + + monkeypatch.setattr(torch, "load", _mock_torch_load) + + model = Cerberus() + model.load_weights_from_file("weights.tar") + + state_key = "backbone.conv1.weight" + assert torch.equal( + model.state_dict()[state_key], + source_model.state_dict()[state_key], + ) + + +def test_cerberus_pretrained_registry(monkeypatch: pytest.MonkeyPatch) -> None: + """Test the Cerberus pretrained registry entry and model IO config.""" + checkpoint = {"desc": _module_prefixed_state_dict(Cerberus())} + + def _mock_torch_load( + *_args: object, + **_kwargs: object, + ) -> dict[str, dict[str, torch.Tensor]]: + """Return a synthetic Cerberus checkpoint for registry loading.""" + return checkpoint + + monkeypatch.setattr(torch, "load", _mock_torch_load) + + model, ioconfig = get_pretrained_model( + "cerberus-resnet34", + pretrained_weights="weights.tar", + ) + + assert isinstance(model, Cerberus) + assert isinstance(ioconfig, IOInstanceSegmentorConfig) + assert tuple(ioconfig.patch_input_shape) == (448, 448) + assert tuple(ioconfig.patch_output_shape) == PATCH_OUTPUT_SHAPE + assert tuple(ioconfig.stride_shape) == PATCH_OUTPUT_SHAPE + assert len(ioconfig.output_resolutions) == len(Cerberus.head_names) + + +def test_cerberus_infer_batch_output_shapes() -> None: + """Test Cerberus inference output order and shape.""" + model = Cerberus() + batch = torch.zeros((1, *INFER_INPUT_SHAPE, 3), dtype=torch.uint8) + + outputs = model.infer_batch(model, batch, device="cpu") + + assert len(outputs) == len(Cerberus.head_names) + expected_shapes = ( + (1, *PATCH_OUTPUT_SHAPE, 2), + (1, *PATCH_OUTPUT_SHAPE, 1), + (1, *PATCH_OUTPUT_SHAPE, 2), + (1, *PATCH_OUTPUT_SHAPE, 1), + (1, *PATCH_OUTPUT_SHAPE, 2), + (1, *PATCH_OUTPUT_SHAPE, 1), + ) + for output, expected_shape in zip(outputs, expected_shapes, strict=True): + assert output.shape == expected_shape + assert output.dtype == np.float32 + + +def test_cerberus_postproc_empty_maps() -> None: + """Test Cerberus post-processing output structure for empty predictions.""" + raw_maps = [ + np.zeros((*PATCH_OUTPUT_SHAPE, 2), dtype=np.float32), + np.zeros((*PATCH_OUTPUT_SHAPE, 1), dtype=np.float32), + np.zeros((*PATCH_OUTPUT_SHAPE, 2), dtype=np.float32), + np.zeros((*PATCH_OUTPUT_SHAPE, 1), dtype=np.float32), + np.zeros((*PATCH_OUTPUT_SHAPE, 2), dtype=np.float32), + np.zeros((*PATCH_OUTPUT_SHAPE, 1), dtype=np.float32), + ] + + outputs = Cerberus().postproc(raw_maps, offset=(3, 5)) + + assert [output["task_type"] for output in outputs] == ["nuclei", "gland", "lumen"] + for output in outputs: + assert output["seg_type"] == "instance" + assert output["predictions"].shape == PATCH_OUTPUT_SHAPE + assert output["predictions"].dtype == np.int32 + + info_dict = output["info_dict"] + assert info_dict["box"].shape == (0, 4) + assert info_dict["box"].dtype == np.int32 + assert info_dict["centroid"].shape == (0, 2) + assert info_dict["centroid"].dtype == np.float32 + assert info_dict["contours"].shape == (0, 0, 2) + assert info_dict["contours"].dtype == np.int32 + assert info_dict["prob"].shape == (0,) + assert info_dict["prob"].dtype == np.float32 + assert info_dict["type"].shape == (0,) + assert info_dict["type"].dtype == np.int32 + + +def test_cerberus_postproc_dask_maps_and_lumen_gland_mask( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test Cerberus post-processing Dask output and lumen-in-gland masking.""" + output_shape = (16, 16) + raw_maps = [ + da.from_array( + np.zeros((*output_shape, channels), dtype=np.float32), + chunks=(8, 8, channels), + ) + for channels in (2, 1, 2, 1, 2, 1) + ] + calls = [] + + def _mock_post_process( + raw_map: np.ndarray, + idx_dict: dict[str, list[int]], + tissue_mode: str, + ds_factor: float, + ) -> tuple[np.ndarray, np.ndarray | None]: + """Return deterministic task maps for Cerberus postproc testing.""" + calls.append((tissue_mode, raw_map.shape, idx_dict, ds_factor)) + inst_map = np.zeros(output_shape, dtype=np.int32) + type_map = np.ones(output_shape, dtype=np.uint8) + if tissue_mode == "Nuclei": + inst_map[2:5, 2:5] = 1 + elif tissue_mode == "Gland": + inst_map[1:8, 1:8] = 1 + else: + inst_map[3:6, 3:6] = 1 + inst_map[10:13, 10:13] = 2 + type_map = None + return inst_map, type_map + + def _mock_get_instance_info( + inst_map: np.ndarray, + type_map: np.ndarray | None, + offset: tuple[int, int], + verbose: object, + ) -> dict[int, dict]: + """Return deterministic instance metadata for Cerberus postproc tests.""" + assert offset == (7, 11) + assert verbose is False + type_value = 0 if type_map is None else int(type_map[inst_map > 0][0]) + return { + 1: { + "box": np.array([1, 2, 3, 4], dtype=np.int32), + "centroid": np.array([2.5, 3.5], dtype=np.float32), + "contours": np.array([[1, 2], [3, 4]], dtype=np.int32), + "prob": 0.75, + "type": type_value, + }, + } + + monkeypatch.setattr( + PostProcInstErodedContourMap, + "post_process", + _mock_post_process, + ) + monkeypatch.setattr( + "tiatoolbox.models.architecture.cerberus.model.HoVerNet.get_instance_info", + _mock_get_instance_info, + ) + + outputs = Cerberus().postproc(raw_maps, offset=(7, 11)) + + assert [call[0] for call in calls] == ["Nuclei", "Gland", "Lumen"] + assert calls[0][1:] == ( + (16, 16, 3), + {"Nuclei-INST": [0, 2], "Nuclei-TYPE": [2, 3]}, + 1.0, + ) + assert [output["task_type"] for output in outputs] == ["nuclei", "gland", "lumen"] + lumen_map = outputs[2]["predictions"].compute() + assert np.all(lumen_map[3:6, 3:6] == 1) + assert np.all(lumen_map[10:13, 10:13] == 0) + for output in outputs: + assert isinstance(output["predictions"], da.Array) + assert output["predictions"].dtype == np.int32 + assert output["info_dict"]["box"].compute().dtype == np.int32 + assert output["info_dict"]["centroid"].compute().dtype == np.float32 + assert output["info_dict"]["contours"].compute().shape == (1, 2, 2) + assert output["info_dict"]["prob"].compute().dtype == np.float32 + assert output["info_dict"]["type"].compute().dtype == np.int32 + + +def test_cerberus_model_helpers() -> None: + """Test Cerberus private helper conversions.""" + tissue_map, idx_dict = _build_tissue_raw_map( + { + "Nuclei-INST": np.zeros((4, 5, 2), dtype=np.float32), + "Nuclei-TYPE": np.ones((4, 5), dtype=np.float32), + }, + "Nuclei", + ) + assert tissue_map.shape == (4, 5, 3) + assert idx_dict == {"Nuclei-INST": [0, 2], "Nuclei-TYPE": [2, 3]} + + tensor = torch.arange(1 * 5 * 6 * 1, dtype=torch.float32).reshape(1, 5, 6, 1) + cropped = _crop_center_tensor(tensor, (3, 4)) + assert cropped.shape == (1, 3, 4, 1) + assert torch.equal(cropped, tensor[:, 1:4, 1:5, :]) + + contours = np.array( + [ + np.array([[1, 2], [3, 4]], dtype=np.int32), + np.array([[5, 6]], dtype=np.int32), + ], + dtype=object, + ) + padded = _pad_contours(contours) + assert padded.shape == (2, 2, 2) + assert np.array_equal(padded[1, 0], [5, 6]) + assert np.array_equal(padded[1, 1], [np.iinfo(np.int32).min] * 2) + + dask_info = _inst_dict_for_dask_processing({}, is_dask=True) + assert dask_info["contours"].compute().shape == (0, 0, 2) + assert dask_info["type"].compute().dtype == np.int32 + + +def test_cerberus_eroded_contour_postproc_non_empty_and_errors() -> None: + """Test non-empty Cerberus contour post-processing and validation errors.""" + nuclei_raw_map = np.zeros((40, 40, 2), dtype=np.float32) + nuclei_raw_map[6:18, 6:18, 0] = 0.9 + nuclei_raw_map[22:34, 22:34, 0] = 0.9 + nuclei_inst_map, nuclei_type_map = PostProcInstErodedContourMap.post_process( + raw_map=nuclei_raw_map, + idx_dict={"Nuclei-INST": [0, 2]}, + tissue_mode="Nuclei", + ) + assert nuclei_inst_map.shape == (40, 40) + assert nuclei_inst_map.max() == 2 + assert get_bounding_box(nuclei_inst_map > 0) == (7, 33, 7, 33) + assert nuclei_type_map is None + + gland_raw_map = np.zeros((80, 80, 3), dtype=np.float32) + gland_raw_map[10:60, 10:60, 0] = 0.9 + gland_raw_map[..., 2] = 2 + + inst_map, type_map = PostProcInstErodedContourMap.post_process( + raw_map=gland_raw_map, + idx_dict={"Gland-INST": [0, 2], "Gland-TYPE": [2, 3]}, + tissue_mode="Gland", + ) + + assert inst_map.shape == (80, 80) + assert inst_map.max() == 1 + assert type_map is not None + assert type_map.shape == (80, 80) + assert np.all(type_map == 2) + assert get_bounding_box(inst_map > 0) == (6, 65, 6, 65) + + lumen_raw_map = np.zeros((40, 40, 2), dtype=np.float32) + lumen_raw_map[8:25, 8:25, 0] = 0.9 + lumen_inst_map, lumen_type_map = PostProcInstErodedContourMap.post_process( + raw_map=lumen_raw_map, + idx_dict={"Lumen-INST": [0, 2]}, + tissue_mode="Lumen", + ) + assert lumen_inst_map.max() == 1 + assert lumen_type_map is None + + with pytest.raises(ValueError, match="Unsupported Cerberus tissue mode"): + PostProcInstErodedContourMap.post_process( + raw_map=lumen_raw_map, + idx_dict={"Lumen-INST": [0, 2]}, + tissue_mode="Stroma", + ) + + with pytest.raises(KeyError, match="Missing required Cerberus map"): + PostProcInstErodedContourMap.post_process( + raw_map=lumen_raw_map, + idx_dict={}, + tissue_mode="Lumen", + ) diff --git a/tests/test_annotation_utils.py b/tests/test_annotation_utils.py new file mode 100644 index 000000000..ec74a3041 --- /dev/null +++ b/tests/test_annotation_utils.py @@ -0,0 +1,92 @@ +"""Tests for annotation utility helpers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from shapely.geometry import Point + +from tiatoolbox.annotation.storage import Annotation, SQLiteStore +from tiatoolbox.annotation.utils import combine_annotation_stores + +if TYPE_CHECKING: + from pathlib import Path + + +def _write_store( + path: Path, + annotation: Annotation, + key: str, +) -> None: + """Write a one-annotation SQLite store.""" + store = SQLiteStore(path) + store.append_many([annotation], keys=[key]) + store.close() + + +def test_combine_annotation_stores_preserves_annotations_and_labels( + track_tmp_path: Path, +) -> None: + """Test combining SQLite stores with explicit source labels.""" + store_a_path = track_tmp_path / "store-a.db" + store_b_path = track_tmp_path / "store-b.db" + output_path = track_tmp_path / "combined.db" + _write_store( + store_a_path, + Annotation(Point(1, 2), {"class": 1}), + "ann-a", + ) + _write_store( + store_b_path, + Annotation(Point(3, 4), {"class": 2}), + "ann-b", + ) + + result_path = combine_annotation_stores( + [store_a_path, store_b_path], + output_path, + labels={store_a_path: "alpha", store_b_path.resolve(): "beta"}, + label_property="dataset", + ) + + assert result_path == output_path + combined_store = SQLiteStore(output_path) + assert set(combined_store.keys()) == {"alpha:ann-a", "beta:ann-b"} + assert combined_store["alpha:ann-a"].geometry == Point(1, 2) + assert combined_store["alpha:ann-a"].properties == { + "class": 1, + "dataset": "alpha", + } + assert combined_store["beta:ann-b"].geometry == Point(3, 4) + assert combined_store["beta:ann-b"].properties == { + "class": 2, + "dataset": "beta", + } + combined_store.close() + + +def test_combine_annotation_stores_defaults_to_stems_and_checks_output( + track_tmp_path: Path, +) -> None: + """Test default labels, overwrite protection, and empty input validation.""" + source_path = track_tmp_path / "source.db" + output_path = track_tmp_path / "combined.db" + _write_store(source_path, Annotation(Point(5, 6), {"score": 0.5}), "ann") + + combine_annotation_stores([source_path], output_path) + combined_store = SQLiteStore(output_path) + assert set(combined_store.keys()) == {"source:ann"} + assert combined_store["source:ann"].properties == { + "score": 0.5, + "source": "source", + } + combined_store.close() + + with pytest.raises(FileExistsError, match="already exists"): + combine_annotation_stores([source_path], output_path) + + combine_annotation_stores([source_path], output_path, overwrite=True) + + with pytest.raises(ValueError, match="At least one"): + combine_annotation_stores([], output_path, overwrite=True) diff --git a/tiatoolbox/annotation/__init__.py b/tiatoolbox/annotation/__init__.py index 99dfa07ec..de8af1672 100644 --- a/tiatoolbox/annotation/__init__.py +++ b/tiatoolbox/annotation/__init__.py @@ -7,5 +7,12 @@ DictionaryStore, SQLiteStore, ) +from tiatoolbox.annotation.utils import combine_annotation_stores -__all__ = ["Annotation", "AnnotationStore", "DictionaryStore", "SQLiteStore"] +__all__ = [ + "Annotation", + "AnnotationStore", + "DictionaryStore", + "SQLiteStore", + "combine_annotation_stores", +] diff --git a/tiatoolbox/annotation/utils.py b/tiatoolbox/annotation/utils.py new file mode 100644 index 000000000..821ad2bac --- /dev/null +++ b/tiatoolbox/annotation/utils.py @@ -0,0 +1,96 @@ +"""Utilities for working with annotation stores.""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +from tiatoolbox.annotation.storage import Annotation, SQLiteStore + +if TYPE_CHECKING: # pragma: no cover + from collections.abc import Iterable, Mapping, Sequence + + +def combine_annotation_stores( + input_paths: Sequence[str | Path], + output_path: str | Path, + labels: Mapping[str | Path, str] | None = None, + *, + label_property: str = "source", + overwrite: bool = False, +) -> Path: + """Combine multiple SQLite annotation stores into one store. + + Args: + input_paths: + Paths to SQLite-backed ``.db`` annotation stores. + output_path: + Path to write the combined ``.db`` annotation store. + labels: + Optional mapping from input path to a label to write into each + annotation's properties under ``label_property``. If omitted, each + source store's filename stem is used. + label_property: + Name of the property used to record the source label. + overwrite: + Whether to replace an existing output store. + + Returns: + Path: + Path to the combined annotation store. + + """ + input_path_objs = [Path(path) for path in input_paths] + if len(input_path_objs) == 0: + msg = "At least one input annotation store path is required." + raise ValueError(msg) + + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + if output_path.exists(): + if not overwrite: + msg = f"Output annotation store already exists: {output_path}" + raise FileExistsError(msg) + output_path.unlink() + + labels_ = _normalise_labels(input_path_objs, labels) + combined_store = SQLiteStore(auto_commit=False) + + for source_path in input_path_objs: + source_store = SQLiteStore.open(source_path) + source_label = labels_[source_path] + annotations = [] + keys = [] + for key, annotation in source_store.items(): + properties = dict(annotation.properties) + properties[label_property] = source_label + annotations.append(Annotation(annotation.geometry, properties)) + keys.append(f"{source_label}:{key}") + if annotations: + combined_store.append_many(annotations, keys) + + combined_store.commit() + combined_store.dump(output_path) + return output_path + + +def _normalise_labels( + input_paths: Iterable[Path], + labels: Mapping[str | Path, str] | None, +) -> dict[Path, str]: + """Normalise optional path labels to resolved ``Path`` keys.""" + input_paths = list(input_paths) + if labels is None: + return {path: path.stem for path in input_paths} + + labels_by_path = {Path(path): label for path, label in labels.items()} + labels_by_resolved_path = { + Path(path).resolve(): label for path, label in labels.items() + } + normalised = {} + for path in input_paths: + normalised[path] = labels_by_path.get( + path, + labels_by_resolved_path.get(path.resolve(), path.stem), + ) + return normalised diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index c8aa06d88..828b28e79 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -671,6 +671,33 @@ hovernet_fast-pannuke: save_resolution: {'units': 'mpp', 'resolution': 0.25} ignore_index: 0 +cerberus-resnet34: + hf_repo_id: TIACentre/TIAToolbox_pretrained_weights + architecture: + class: cerberus.Cerberus + kwargs: + patch_output_shape: [144, 144] + ioconfig: + class: io_config.IOInstanceSegmentorConfig + kwargs: + input_resolutions: + - {"units": "mpp", "resolution": 0.50} + output_resolutions: + - {"units": "mpp", "resolution": 0.50} + - {"units": "mpp", "resolution": 0.50} + - {"units": "mpp", "resolution": 0.50} + - {"units": "mpp", "resolution": 0.50} + - {"units": "mpp", "resolution": 0.50} + - {"units": "mpp", "resolution": 0.50} + margin: 512 + postproc_halo: 512 + tile_shape: [4096, 4096] + patch_input_shape: [448, 448] + patch_output_shape: [144, 144] + stride_shape: [144, 144] + save_resolution: {'units': 'mpp', 'resolution': 0.50} + ignore_index: 0 + hovernet_fast-monusac: hf_repo_id: TIACentre/TIAToolbox_pretrained_weights architecture: diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index 0885c99ad..eff0a7f95 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations from . import architecture, dataset, engine, models_abc +from .architecture.cerberus import Cerberus from .architecture.hovernet import HoVerNet from .architecture.hovernetplus import HoVerNetPlus from .architecture.idars import IDaRS @@ -29,6 +30,7 @@ __all__ = [ "SAM", "SCCNN", + "Cerberus", "DeepFeatureExtractor", "HoVerNet", "HoVerNetPlus", diff --git a/tiatoolbox/models/architecture/cerberus/__init__.py b/tiatoolbox/models/architecture/cerberus/__init__.py new file mode 100644 index 000000000..4247ef05f --- /dev/null +++ b/tiatoolbox/models/architecture/cerberus/__init__.py @@ -0,0 +1,7 @@ +"""Cerberus multi-task segmentation architecture.""" + +from __future__ import annotations + +from .model import Cerberus + +__all__ = ["Cerberus"] diff --git a/tiatoolbox/models/architecture/cerberus/backbone/__init__.py b/tiatoolbox/models/architecture/cerberus/backbone/__init__.py new file mode 100644 index 000000000..df58e4665 --- /dev/null +++ b/tiatoolbox/models/architecture/cerberus/backbone/__init__.py @@ -0,0 +1,5 @@ +"""Backbone used by the released Cerberus checkpoint.""" + +from .resnet import ResNet34, resnet34 + +__all__ = ["ResNet34", "resnet34"] diff --git a/tiatoolbox/models/architecture/cerberus/backbone/resnet.py b/tiatoolbox/models/architecture/cerberus/backbone/resnet.py new file mode 100644 index 000000000..49bf92e6f --- /dev/null +++ b/tiatoolbox/models/architecture/cerberus/backbone/resnet.py @@ -0,0 +1,111 @@ +"""Minimal ResNet-34 feature extractor for the Cerberus checkpoint.""" + +from __future__ import annotations + +import torch +from torch import nn + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """3x3 convolution with padding.""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + ) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution.""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + """Basic residual block used by ResNet-34.""" + + expansion = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: nn.Module | None = None, + ) -> None: + """Initialize a ResNet-34 residual block.""" + super().__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the residual block.""" + identity = x + out = self.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + if self.downsample is not None: + identity = self.downsample(x) + return self.relu(out + identity) + + +class ResNet34(nn.Module): + """ResNet-34 variant used by Cerberus. + + The first convolution uses stride 1 and the forward pass returns feature maps + from each encoder stage instead of classifier logits. + """ + + def __init__(self) -> None: + """Initialize the Cerberus ResNet-34 encoder.""" + super().__init__() + self.inplanes = 64 + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(64, 3) + self.layer2 = self._make_layer(128, 4, stride=2) + self.layer3 = self._make_layer(256, 6, stride=2) + self.layer4 = self._make_layer(512, 3, stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512, 1000) + + def _make_layer( + self, + planes: int, + blocks: int, + stride: int = 1, + ) -> nn.Sequential: + """Build one ResNet stage with optional downsampling.""" + downsample = None + if stride != 1 or self.inplanes != planes * BasicBlock.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * BasicBlock.expansion, stride), + nn.BatchNorm2d(planes * BasicBlock.expansion), + ) + + layers = [BasicBlock(self.inplanes, planes, stride, downsample)] + self.inplanes = planes * BasicBlock.expansion + layers.extend(BasicBlock(self.inplanes, planes) for _ in range(1, blocks)) + return nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + """Return feature maps from the encoder pyramid.""" + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(self.maxpool(x0)) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + x4 = self.layer4(x3) + return [x0, x1, x2, x3, x4] + + +def resnet34() -> ResNet34: + """Build the Cerberus ResNet-34 encoder.""" + return ResNet34() diff --git a/tiatoolbox/models/architecture/cerberus/model.py b/tiatoolbox/models/architecture/cerberus/model.py new file mode 100644 index 000000000..d890365f9 --- /dev/null +++ b/tiatoolbox/models/architecture/cerberus/model.py @@ -0,0 +1,263 @@ +"""TIAToolbox integration wrapper for the Cerberus architecture.""" + +from __future__ import annotations + +from collections import OrderedDict +from typing import TYPE_CHECKING + +import dask.array as da +import numpy as np +import pandas as pd +import torch +from torch import nn +from torch.nn import functional + +from tiatoolbox.models.architecture.hovernet import HoVerNet +from tiatoolbox.models.models_abc import ModelABC + +from .net_desc import NetDesc +from .postproc import PostProcInstErodedContourMap + +SPATIAL_NDIMS = 2 + +if TYPE_CHECKING: # pragma: no cover + from pathlib import Path + + +class Cerberus(ModelABC, NetDesc): + """Cerberus multi-task model for glands, lumen, nuclei, and patch class.""" + + head_names = ( + "Nuclei-INST", + "Nuclei-TYPE", + "Gland-INST", + "Gland-TYPE", + "Lumen-INST", + "Patch-Class", + ) + + def __init__( + self, + patch_output_shape: tuple[int, int] = (144, 144), + nuclei_type_dict: dict | None = None, + gland_type_dict: dict | None = None, + lumen_type_dict: dict | None = None, + ) -> None: + """Initialize the fixed Cerberus ResNet-34 model.""" + ModelABC.__init__(self) + NetDesc.__init__(self) + self.patch_output_shape = tuple(patch_output_shape) + self.tasks = ("nuclei", "gland", "lumen") + self.class_dict = { + "nuclei": nuclei_type_dict + or { + 0: "Background", + 1: "Neutrophil", + 2: "Epithelial", + 3: "Lymphocyte", + 4: "Plasma", + 5: "Eosinophil", + 6: "Connective", + }, + "gland": gland_type_dict + or {0: "Background", 1: "Gland", 2: "Surface Epithelium"}, + "lumen": lumen_type_dict or {0: "Background", 1: "Lumen"}, + } + + def forward( # skipcq: PYL-W0221 + self, imgs: torch.Tensor, train_decoder_list: list[str] | None = None + ) -> OrderedDict: + """Forward pass through the shared encoder and selected Cerberus decoders.""" + return NetDesc.forward(self, imgs, train_decoder_list or []) + + def load_weights_from_file(self, weights: str | Path) -> torch.nn.Module: + """Load Cerberus weights saved as ``weights.tar`` or a plain state dict.""" + state = torch.load(weights, map_location="cpu") + state = state["desc"] if isinstance(state, dict) and "desc" in state else state + state = _strip_dataparallel_prefix(state) + self.load_state_dict(state, strict=True) + return self + + @staticmethod + def infer_batch( + model: nn.Module, batch_data: np.ndarray | torch.Tensor, *, device: str + ) -> tuple[np.ndarray, ...]: + """Run Cerberus inference and return TIAToolbox-compatible head arrays.""" + patch_imgs = batch_data + patch_imgs = patch_imgs.to(device).type(torch.float32) + patch_imgs = patch_imgs.permute(0, 3, 1, 2).contiguous() + + model.eval() + with torch.inference_mode(): + pred_dict = model(patch_imgs) + pred_dict = OrderedDict( + (k, v.permute(0, 2, 3, 1).contiguous()) for k, v in pred_dict.items() + ) + + pred_dict["Nuclei-INST"] = functional.softmax( + pred_dict["Nuclei-INST"], dim=-1 + )[..., 1:] + pred_dict["Gland-INST"] = functional.softmax( + pred_dict["Gland-INST"], dim=-1 + )[..., 1:] + pred_dict["Lumen-INST"] = functional.softmax( + pred_dict["Lumen-INST"], dim=-1 + )[..., 1:] + + for key in ("Nuclei-TYPE", "Gland-TYPE"): + type_map = functional.softmax(pred_dict[key], dim=-1) + pred_dict[key] = torch.argmax(type_map, dim=-1, keepdim=True).type( + torch.float32 + ) + + patch_class = functional.softmax(pred_dict["Patch-Class"], dim=-1) + patch_class = torch.argmax(patch_class, dim=-1, keepdim=True).type( + torch.float32 + ) + model_ = getattr(model, "module", model) + output_shape = tuple(getattr(model_, "patch_output_shape", (144, 144))) + + pred_dict["Patch-Class"] = functional.interpolate( + patch_class.permute(0, 3, 1, 2), + size=output_shape, + mode="nearest", + ).permute(0, 2, 3, 1) + + outputs = [] + for head_name in Cerberus.head_names: + head_output = pred_dict[head_name] + if head_output.shape[1:3] != output_shape: + head_output = _crop_center_tensor(head_output, output_shape) + outputs.append(head_output.cpu().numpy()) + + return tuple(outputs) + + # skipcq: PYL-W0221 # noqa: ERA001 + def postproc( + self, raw_maps: list[np.ndarray | da.Array], offset: tuple[int, int] = (0, 0) + ) -> tuple[dict, ...]: + """Post-process Cerberus heads into annotation-store compatible tasks.""" + is_dask = isinstance(raw_maps[0], da.Array) + maps = [raw_map.compute() if is_dask else raw_map for raw_map in raw_maps] + + head_map = dict(zip(self.head_names, maps, strict=False)) + outputs = [] + gland_inst_map = None + for tissue_name, task_name in ( + ("Nuclei", "nuclei"), + ("Gland", "gland"), + ("Lumen", "lumen"), + ): + raw_map, idx_dict = _build_tissue_raw_map(head_map, tissue_name) + inst_map, type_map = PostProcInstErodedContourMap.post_process( + raw_map=raw_map, + idx_dict=idx_dict, + tissue_mode=tissue_name, + ds_factor=1.0, + ) + if tissue_name == "Gland": + gland_inst_map = inst_map.copy() + if tissue_name == "Lumen" and gland_inst_map is not None: + inst_map = inst_map * (gland_inst_map > 0) + if type_map is not None: + type_map = np.squeeze(type_map).astype("uint8") + + inst_map = inst_map.astype("int32") + inst_info_dict = HoVerNet.get_instance_info( + inst_map, + type_map, + offset=offset, + verbose=False, + ) + info_dict = _inst_dict_for_dask_processing(inst_info_dict, is_dask=is_dask) + outputs.append( + { + "task_type": task_name, + "predictions": da.array(inst_map) if is_dask else inst_map, + "info_dict": info_dict, + "seg_type": "instance", + } + ) + + return tuple(outputs) + + +def _strip_dataparallel_prefix(state: dict) -> dict: + """Remove ``module.`` prefixes from DataParallel checkpoint keys.""" + if all(key.split(".")[0] == "module" for key in state): + return {".".join(key.split(".")[1:]): value for key, value in state.items()} + return state + + +def _crop_center_tensor( + tensor: torch.Tensor, + output_shape: tuple[int, int], +) -> torch.Tensor: + """Crop a BHWC tensor to the requested center output shape.""" + h, w = tensor.shape[1:3] + out_h, out_w = output_shape + top = max((h - out_h) // 2, 0) + left = max((w - out_w) // 2, 0) + return tensor[:, top : top + out_h, left : left + out_w, :] + + +def _build_tissue_raw_map( + head_map: dict[str, np.ndarray], tissue_name: str +) -> tuple[np.ndarray, dict[str, list[int]]]: + """Combine Cerberus heads for one tissue into a raw postproc map.""" + idx_dict = {} + maps = [] + start = 0 + for suffix in ("INST", "TYPE"): + head_name = f"{tissue_name}-{suffix}" + if head_name not in head_map: + continue + tissue_map = head_map[head_name] + if tissue_map.ndim == SPATIAL_NDIMS: + tissue_map = tissue_map[..., None] + maps.append(tissue_map) + stop = start + tissue_map.shape[-1] + idx_dict[head_name] = [start, stop] + start = stop + + return np.concatenate(maps, axis=-1), idx_dict + + +def _inst_dict_for_dask_processing(inst_info_dict: dict, *, is_dask: bool) -> dict: + """Convert instance metadata into arrays with optional Dask wrapping.""" + if not inst_info_dict: + output = { + "box": np.empty((0, 4), dtype=np.int32), + "centroid": np.empty((0, 2), dtype=np.float32), + "contours": np.empty((0, 0, 2), dtype=np.int32), + "prob": np.empty((0,), dtype=np.float32), + "type": np.empty((0,), dtype=np.int32), + } + if is_dask: + return {key: da.from_array(value) for key, value in output.items()} + return output + + inst_info_df = pd.DataFrame(inst_info_dict).transpose() + output = {} + for key, col in inst_info_df.items(): + col_np = col.to_numpy() + if key == "contours": + col_np = _pad_contours(col_np) + elif key in {"box", "type"}: + col_np = np.asarray(col_np.tolist(), dtype=np.int32) + elif key in {"centroid", "prob"}: + col_np = np.asarray(col_np.tolist(), dtype=np.float32) + chunks = (len(col), *col_np.shape[1:]) + output[key] = da.from_array(col_np, chunks=chunks) if is_dask else col_np + return output + + +def _pad_contours(contours: np.ndarray) -> np.ndarray: + """Pad variable-length contours to a rectangular integer array.""" + max_len = max(contour.shape[0] for contour in contours) + pad_value = np.iinfo(np.int32).min + padded = np.full((len(contours), max_len, 2), pad_value, dtype=np.int32) + for idx, contour in enumerate(contours): + contour_ = np.asarray(contour, dtype=np.int32) + padded[idx, : contour_.shape[0], :] = contour_ + return padded diff --git a/tiatoolbox/models/architecture/cerberus/net_desc.py b/tiatoolbox/models/architecture/cerberus/net_desc.py new file mode 100644 index 000000000..7aa015eaf --- /dev/null +++ b/tiatoolbox/models/architecture/cerberus/net_desc.py @@ -0,0 +1,130 @@ +"""Minimal Cerberus network definition for the released ResNet-34 checkpoint.""" + +from __future__ import annotations + +from collections import OrderedDict + +import torch +from torch import nn +from torch.nn import functional + +from .backbone.resnet import resnet34 +from .utils.conv_layers import Conv2d, ConvBlock, PytorchBase + +DECODER_KWARGS = { + "Gland": {"INST": 3}, + "Gland#TYPE": {"TYPE": 3}, + "Lumen": {"INST": 3}, + "Nuclei": {"INST": 3}, + "Nuclei#TYPE": {"TYPE": 7}, + "Patch-Class": {"OUT": 9}, +} + +CONSIDERED_TASKS = { + "Nuclei", + "Nuclei#TYPE", + "Gland", + "Gland#TYPE", + "Lumen", + "Patch-Class", +} + + +def cropping_center(x: torch.Tensor, crop_shape: tuple[int, int]) -> torch.Tensor: + """Crop a batched NCHW tensor at the centre.""" + h0 = int((x.shape[2] - crop_shape[0]) * 0.5) + w0 = int((x.shape[3] - crop_shape[1]) * 0.5) + return x[:, :, h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]] + + +class NetDesc(nn.Module): + """Cerberus model topology used by ``resnet34_cerberus`` weights.""" + + def __init__(self) -> None: + """Initialize the fixed Cerberus model topology.""" + super().__init__() + self.encoder_backbone_name = "resnet34" + self.decoder_info_list = DECODER_KWARGS + self.decoder_info = [64, 64, 128, 256, 512] + + self.backbone = resnet34() + self.conv_map = nn.Conv2d(512, 256, (1, 1), bias=False) + self.decoder_head = nn.ModuleDict() + self.output_head = nn.ModuleDict() + + for decoder_name, output_head in self.decoder_info_list.items(): + if decoder_name not in CONSIDERED_TASKS: + continue + if decoder_name == "Patch-Class": + self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1)) + for output_ch in output_head.values(): + self.decoder_head["Patch-Class"] = nn.Sequential( + OrderedDict( + [ + ("bn1", nn.BatchNorm2d(512, eps=1e-5)), + ("relu1", nn.ReLU(inplace=True)), + ("dropout", nn.Dropout(p=0.3)), + ("conv1", nn.Conv2d(512, 256, 1)), + ("bn2", nn.BatchNorm2d(256, eps=1e-5)), + ("relu2", nn.ReLU(inplace=True)), + ("conv2", nn.Conv2d(256, output_ch, 1)), + ] + ) + ) + continue + + self.decoder_head[decoder_name] = nn.ModuleList( + [ + ConvBlock(256, [256, 128], 3), + ConvBlock(128, [128, 64], 3), + ConvBlock(64, [64, 64], 3), + ConvBlock(64, [64, 64], 3), + ] + ) + decoder_output_head = nn.ModuleDict() + for output_name, output_ch in output_head.items(): + decoder_output_head[output_name] = PytorchBase( + ConvBlock(64, [96], ksize=1), + Conv2d(96, output_ch, ksize=1), + ) + self.output_head[decoder_name] = decoder_output_head + + def forward( + self, + imgs: torch.Tensor, + train_decoder_list: list[str] | None = None, + ) -> OrderedDict: + """Return a dictionary of Cerberus output heads.""" + _ = train_decoder_list + imgs = imgs / 255.0 + feat_list = self.backbone(imgs) + bottom_feats = feat_list[-1] + feat_list[-1] = self.conv_map(bottom_feats) + + output_dict = OrderedDict() + for decoder_name, decoder in self.decoder_head.items(): + if decoder_name == "Patch-Class": + patch_feats = bottom_feats + if patch_feats.shape[-2:] != (9, 9): + patch_feats = cropping_center(patch_feats, (9, 9)) + patch_feats = self.global_avg_pool(patch_feats) + output_dict[decoder_name] = decoder(patch_feats) + continue + + prev_feat = feat_list[-1] + for idx in range(1, len(feat_list)): + prev_feat = functional.interpolate( + prev_feat, + scale_factor=2, + mode="bilinear", + align_corners=False, + ) + prev_feat = decoder[idx - 1](feat_list[-(idx + 1)] + prev_feat) + + decoder_output_head = self.output_head[decoder_name] + for clf_name, clf in decoder_output_head.items(): + output_dict[decoder_name.split("#")[0] + "-" + clf_name] = clf( + prev_feat + ) + + return output_dict diff --git a/tiatoolbox/models/architecture/cerberus/postproc.py b/tiatoolbox/models/architecture/cerberus/postproc.py new file mode 100644 index 000000000..4e0ff8eb7 --- /dev/null +++ b/tiatoolbox/models/architecture/cerberus/postproc.py @@ -0,0 +1,144 @@ +"""Post-processing for the released Cerberus ResNet-34 checkpoint.""" + +from __future__ import annotations + +import cv2 +import numpy as np +from scipy.ndimage import binary_fill_holes, label +from skimage import morphology +from skimage.segmentation import watershed + +CONTOUR_THRESHOLD = 0.5 +GLAND_INNER_THRESHOLD = 0.55 + + +def get_bounding_box(img: np.ndarray) -> tuple[int, int, int, int]: + """Return bounding box as ``rmin, rmax, cmin, cmax``.""" + rows = np.any(img, axis=1) + cols = np.any(img, axis=0) + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + return rmin, rmax + 1, cmin, cmax + 1 + + +class PostProcInstErodedContourMap: + """Cerberus eroded-contour instance post-processing.""" + + @staticmethod + def _proc_gland(inst_fg: np.ndarray, ds_factor: float = 1.0) -> np.ndarray: + """Extract labelled gland instances from inner and contour maps.""" + ksize = int((11 - 1) * ds_factor) + k_disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize)) + + inst_inner_raw = inst_fg[..., 0] + inst_cnt_raw = inst_fg[..., 1] + inst_cnt = inst_cnt_raw.copy() + inst_cnt[inst_cnt > CONTOUR_THRESHOLD] = 1 + inst_cnt[inst_cnt <= CONTOUR_THRESHOLD] = 0 + + inst_fg = np.array((inst_inner_raw - inst_cnt) > GLAND_INNER_THRESHOLD) + inst_fg = morphology.remove_small_objects( + inst_fg, + max_size=int(1000 * (ds_factor**2)), + ) + return _dilate_labelled_instances(inst_fg, k_disk) + + @staticmethod + def _proc_lumen(inst_fg: np.ndarray, ds_factor: float = 1.0) -> np.ndarray: + """Extract labelled lumen instances from inner and contour maps.""" + ksize = int((3 - 1) * ds_factor) + k_disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize)) + + inst_inner_raw = inst_fg[..., 0] + inst_cnt_raw = inst_fg[..., 1] + inst_cnt = inst_cnt_raw.copy() + inst_cnt[inst_cnt > CONTOUR_THRESHOLD] = 1 + inst_cnt[inst_cnt <= CONTOUR_THRESHOLD] = 0 + + inst_fg = np.array((inst_inner_raw - inst_cnt) > CONTOUR_THRESHOLD) + inst_fg = morphology.remove_small_objects( + inst_fg, + max_size=int(150 * (ds_factor**2)), + ) + return _dilate_labelled_instances(inst_fg, k_disk) + + @staticmethod + def _proc_nuclei(inst_fg: np.ndarray, ds_factor: float = 1.0) -> np.ndarray: + """Extract labelled nuclei instances from inner and contour maps.""" + _ = ds_factor + k_disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) + + inst_inner_raw = inst_fg[..., 0] + inst_cnt_raw = inst_fg[..., 1] + inst_raw = inst_inner_raw + inst_cnt_raw + inst_msk = np.array(inst_raw > CONTOUR_THRESHOLD) + + if np.sum(inst_msk) == 0: + return np.zeros(inst_msk.shape) + + inst_msk = cv2.erode(inst_msk.astype("uint8"), k_disk, iterations=1) + inst_msk = label(inst_msk)[0] + inst_msk = morphology.remove_small_objects(inst_msk, max_size=8) + inst_msk = np.array(inst_msk > 0) + + inst_mrk = np.array(inst_inner_raw > CONTOUR_THRESHOLD) + inst_mrk = label(inst_mrk)[0] + inst_mrk = morphology.remove_small_objects(inst_mrk, max_size=4) + + marker = binary_fill_holes(inst_mrk.copy()) + marker = label(marker)[0] + return watershed(-inst_inner_raw, marker, mask=inst_msk) + + @classmethod + def post_process( + cls, + raw_map: np.ndarray, + idx_dict: dict[str, list[int]], + tissue_mode: str, + ds_factor: float = 1.0, + ) -> tuple[np.ndarray, np.ndarray | None]: + """Convert Cerberus raw maps into instance and optional type maps.""" + func_dict = { + "LUMEN": cls._proc_lumen, + "GLAND": cls._proc_gland, + "NUCLEI": cls._proc_nuclei, + } + tissue_key = tissue_mode.upper() + if tissue_key not in func_dict: + msg = f"Unsupported Cerberus tissue mode: {tissue_mode}" + raise ValueError(msg) + + tissue_ch = f"{tissue_mode}-INST" + if tissue_ch not in idx_dict: + msg = f"Missing required Cerberus map: {tissue_ch}" + raise KeyError(msg) + + inst_fg = raw_map[..., idx_dict[tissue_ch][0] : idx_dict[tissue_ch][1]] + inst_map = func_dict[tissue_key](inst_fg, ds_factor) + + type_ch = f"{tissue_mode}-TYPE" + if type_ch not in idx_dict: + return inst_map, None + + type_map = raw_map[..., idx_dict[type_ch][0] : idx_dict[type_ch][1]] + return inst_map, np.squeeze(type_map) + + +def _dilate_labelled_instances(inst_fg: np.ndarray, k_disk: np.ndarray) -> np.ndarray: + """Label foreground instances, dilate each object, and fill holes.""" + inst_lab = label(inst_fg)[0] + output_map = np.zeros(inst_lab.shape) + for inst_id in np.unique(inst_lab).tolist()[1:]: + inst_map = np.array(inst_lab == inst_id, dtype=np.uint8) + y1, y2, x1, x2 = get_bounding_box(inst_map) + pad = k_disk.shape[0] * 2 + y1 = max(y1 - pad, 0) + x1 = max(x1 - pad, 0) + x2 = min(x2 + pad, inst_map.shape[1] - 1) + y2 = min(y2 + pad, inst_map.shape[0] - 1) + inst_map_crop = inst_map[y1:y2, x1:x2] + inst_map_crop = cv2.dilate(inst_map_crop, k_disk, iterations=1) + inst_map_crop = binary_fill_holes(inst_map_crop) + output_region = output_map[y1:y2, x1:x2] + output_region[inst_map_crop > 0] = inst_id + return output_map diff --git a/tiatoolbox/models/architecture/cerberus/utils/__init__.py b/tiatoolbox/models/architecture/cerberus/utils/__init__.py new file mode 100644 index 000000000..a970b748d --- /dev/null +++ b/tiatoolbox/models/architecture/cerberus/utils/__init__.py @@ -0,0 +1 @@ +"""Minimal decoder utilities for Cerberus.""" diff --git a/tiatoolbox/models/architecture/cerberus/utils/conv_layers.py b/tiatoolbox/models/architecture/cerberus/utils/conv_layers.py new file mode 100644 index 000000000..0ed22fe05 --- /dev/null +++ b/tiatoolbox/models/architecture/cerberus/utils/conv_layers.py @@ -0,0 +1,100 @@ +"""Minimal convolution blocks required by the Cerberus ResNet-34 decoder.""" + +from __future__ import annotations + +import torch +from torch import nn + + +class Conv2d(nn.Module): + """Convolution wrapper preserving checkpoint module names.""" + + def __init__( + self, + in_ch: int, + out_ch: int, + ksize: int, + *, + pad: bool = True, + ) -> None: + """Initialize the convolution layer.""" + super().__init__() + pad_size = int(ksize // 2) if pad else 0 + self.conv = nn.Conv2d( + in_ch, + out_ch, + ksize, + stride=1, + padding=pad_size, + bias=True, + ) + + def forward(self, prev_feat: torch.Tensor) -> torch.Tensor: + """Apply convolution.""" + return self.conv(prev_feat) + + +class _ConvLayer(nn.Module): + """Conv-BN-ReLU block used by the released Cerberus decoder.""" + + def __init__( + self, + in_ch: int, + out_ch: int, + ksize: int, + *, + pad: bool = True, + ) -> None: + """Initialize the convolution, batch normalization, and activation.""" + super().__init__() + pad_size = int(ksize // 2) if pad else 0 + self.preact = False + self.bn = nn.BatchNorm2d(out_ch, eps=1e-5) + self.relu = nn.ReLU(inplace=True) + self.conv = nn.Conv2d(in_ch, out_ch, ksize, padding=pad_size, bias=True) + + def forward(self, prev_feat: torch.Tensor) -> torch.Tensor: + """Apply convolution followed by batch norm and ReLU.""" + feat = self.conv(prev_feat) + feat = self.bn(feat) + return self.relu(feat) + + +class ConvBlock(nn.Module): + """A sequence of Cerberus convolution layers.""" + + def __init__( + self, + in_ch: int, + unit_ch: list[int], + ksize: int, + *, + pad: bool = True, + ) -> None: + """Initialize the convolution block.""" + super().__init__() + self.nr_layers = len(unit_ch) + self.block = nn.ModuleList() + for idx in range(self.nr_layers): + self.block.append(_ConvLayer(in_ch, unit_ch[idx], ksize, pad=pad)) + in_ch = unit_ch[idx] + + def forward(self, prev_feat: torch.Tensor) -> torch.Tensor: + """Apply each convolution layer in order.""" + feat = prev_feat + for idx in range(self.nr_layers): + feat = self.block[idx](feat) + return feat + + +class PytorchBase(nn.Module): + """Sequential wrapper preserving original checkpoint key prefix ``x``.""" + + def __init__(self, *args: nn.Module) -> None: + """Initialize the sequential wrapper.""" + super().__init__() + self.x = nn.Sequential(*args) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply wrapped modules.""" + return self.x(x) diff --git a/tiatoolbox/models/engine/io_config.py b/tiatoolbox/models/engine/io_config.py index b3af1e557..bab1357d7 100644 --- a/tiatoolbox/models/engine/io_config.py +++ b/tiatoolbox/models/engine/io_config.py @@ -233,6 +233,10 @@ class IOSegmentorConfig(ModelIOConfigABC): Resolution to save all output. tile_shape (tuple(int, int)): Tile shape to process the WSI. + postproc_halo (int | tuple[int, int]): + Optional extra context around each post-processing tile. If set, the + engine post-processes an expanded tile and keeps objects owned by the + original tile core. Attributes: input_resolutions (list(dict)): @@ -257,6 +261,10 @@ class IOSegmentorConfig(ModelIOConfigABC): Tile shape to process the WSI. margin (int): Tile margin to accumulate the output. + postproc_halo (int | tuple[int, int]): + Optional extra context around each post-processing tile. If set, the + engine post-processes an expanded tile and keeps objects owned by the + original tile core. Examples: >>> # Defining io for a network having 1 input and 1 output at the @@ -294,6 +302,7 @@ class IOSegmentorConfig(ModelIOConfigABC): save_resolution: dict = None tile_shape: tuple[int, int] | None = None margin: int | None = None + postproc_halo: int | tuple[int, int] | None = None def to_baseline(self: IOSegmentorConfig) -> IOSegmentorConfig: """Returns a new config object converted to baseline form. @@ -389,6 +398,10 @@ class IOInstanceSegmentorConfig(IOSegmentorConfig): Tile margin to accumulate the output. tile_shape (tuple(int, int)): Tile shape to process the WSI. + postproc_halo (int | tuple[int, int]): + Optional extra context around each post-processing tile. If set, the + engine post-processes an expanded tile and keeps objects owned by the + original tile core. Attributes: input_resolutions (list(dict)): @@ -413,6 +426,10 @@ class IOInstanceSegmentorConfig(IOSegmentorConfig): Tile margin to accumulate the output. tile_shape (tuple(int, int)): Tile shape to process the WSI. + postproc_halo (int | tuple[int, int]): + Optional extra context around each post-processing tile. If set, the + engine post-processes an expanded tile and keeps objects owned by the + original tile core. Examples: >>> # Defining io for a network having 1 input and 1 output at the diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index f183f325c..049682427 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -204,6 +204,10 @@ class MultiTaskSegmentorRunParams(SemanticSegmentorRunParams, total=False): Number of workers used in DataLoader. output_file (str): Output file name for saving results (e.g., .zarr or .db). + postproc_halo (int | tuple[int, int]): + Optional halo around each WSI post-processing tile. When set, tile-mode + post-processing runs on the expanded tile and keeps objects owned by + the original tile core. output_resolutions (Resolution): Resolution used for writing output predictions. patch_output_shape (tuple[int, int]): @@ -224,6 +228,7 @@ class MultiTaskSegmentorRunParams(SemanticSegmentorRunParams, total=False): """ + postproc_halo: int | tuple[int, int] return_predictions: tuple[bool, ...] @@ -949,11 +954,15 @@ def post_process_wsi( # skipcq: PYL-R0201 if self.num_workers == 0 else self.num_workers ) + postproc_halo = kwargs.get("postproc_halo") + if postproc_halo is None: + postproc_halo = getattr(self._ioconfig, "postproc_halo", None) post_process_predictions = self._process_tile_mode( probabilities=probabilities, save_path=save_path.with_suffix(".zarr"), memory_threshold=kwargs.get("memory_threshold", 80), return_predictions=kwargs.get("return_predictions"), + postproc_halo=postproc_halo, ) else: post_process_predictions = self._process_full_wsi( @@ -1083,6 +1092,7 @@ def _process_tile_mode( memory_threshold: float = 80, *, return_predictions: tuple[bool, ...] | None = None, + postproc_halo: int | tuple[int, int] | None = None, ) -> tuple[dict, ...] | None: """Convert WSI probability maps into outputs using tile-mode processing. @@ -1120,6 +1130,11 @@ def _process_tile_mode( prediction arrays are retained (i.e., they are set to ``None`` and not allocated). The tuple length must match the number of task dictionaries produced by ``postproc_func``. + postproc_halo (int | tuple[int, int] | None): + Optional halo around each tile before post-processing. Tuple values + follow image-shape order ``(height, width)``. With a non-zero halo, + only core grid tiles are processed; objects are kept if owned by the + unexpanded tile core. Returns: list[dict] | None: @@ -1159,16 +1174,21 @@ def _process_tile_mode( # assume ioconfig has already been converted to `baseline` for `tile` mode wsi_proc_shape = wsi_reader.slide_dimensions(**highest_input_resolution) - masked_output_shape = ( - self.mask_bounds[2] - self.mask_bounds[0], # X/row - self.mask_bounds[3] - self.mask_bounds[1], # Y/col - ) + # Tile over the actual probability canvas, which may be larger than the + # mask bounding box because inference keeps whole patch-output regions. + masked_output_shape = np.array(probabilities[0].shape[:2][::-1]) # * retrieve tile placement and tile info flag # tile shape will always be corrected to be multiple of output tile_info_sets = self._get_tile_info( image_shape=masked_output_shape, wsi_proc_shape=wsi_proc_shape ) + postproc_halo_xy = _normalise_postproc_halo(postproc_halo) + use_postproc_halo = np.any(postproc_halo_xy > 0) + if use_postproc_halo: + tile_info_sets = [ + [tile_info_sets[0][0], np.zeros_like(tile_info_sets[0][1])] + ] ioconfig = self._ioconfig.to_baseline() tile_metadata = _build_tile_tasks( @@ -1182,7 +1202,8 @@ def _process_tile_mode( # Calculate batch size for dask compute vm = psutil.virtual_memory() bytes_per_element = np.dtype(probabilities[0].dtype).itemsize - tile_elements = np.prod(self._ioconfig.tile_shape) + tile_shape = np.array(self._ioconfig.tile_shape) + tile_elements = np.prod(tile_shape + (2 * postproc_halo_xy[::-1])) prod_dim2 = math.prod(p.shape[2] for p in probabilities if len(p.shape) > 2) # noqa: PLR2004 tile_memory = len(probabilities) * tile_elements * prod_dim2 * bytes_per_element # available memory @@ -1198,17 +1219,27 @@ def _process_tile_mode( disable=not self.verbose, ): tile_metadata_ = tile_metadata[i : i + batch_size] + tile_read_bounds = [ + _get_postproc_tile_read_bounds( + tile_bounds=tile_meta[0], + postproc_halo_xy=postproc_halo_xy, + image_shape=masked_output_shape, + ) + for tile_meta in tile_metadata_ + ] # Build delayed tasks delayed_tasks = [ self._compute_tile( - _tile_meta[0], + tile_read_bounds[_tile_id], ) - for _tile_meta in tqdm( - tile_metadata_, - leave=False, - desc="Creating list of delayed tasks for post-processing", - disable=not self.verbose, + for _tile_id, _ in enumerate( + tqdm( + tile_metadata_, + leave=False, + desc="Creating list of delayed tasks for post-processing", + disable=not self.verbose, + ) ) ] @@ -1232,6 +1263,13 @@ def _process_tile_mode( # Merge each tile result for _tile_id, post_process_output in enumerate(tqdm_loop): tile_bounds, tile_flag, tile_mode = tile_metadata_[_tile_id] + tile_read_bounds_ = tile_read_bounds[_tile_id] + if use_postproc_halo: + post_process_output = _crop_halo_post_process_output( # noqa: PLW2901 + post_process_output=post_process_output, + tile_bounds=tile_bounds, + tile_read_bounds=tile_read_bounds_, + ) # create a list of info dict for each task wsi_info_dict = _create_wsi_info_dict( @@ -2608,19 +2646,24 @@ def merge_multitask_vertical_chunkwise( chunk_shape=chunk_shape, probabilities_zarr=probabilities_zarr[idx], probabilities_da=probabilities_da[idx], - zarr_group=zarr_group, + zarr_group=( + zarr_group if probabilities_zarr[idx] is not None else None + ), name=f"probabilities/{idx}", ) - probabilities_zarr, probabilities_da = _save_multitask_vertical_to_cache( - probabilities_zarr=probabilities_zarr, - probabilities_da=probabilities_da, - probabilities=probabilities, - idx=idx, - tqdm_loop=tqdm_loop, - save_path=save_path, - chunk_shape=chunk_shape, - memory_threshold=memory_threshold, + probabilities_zarr, probabilities_da, zarr_group = ( + _save_multitask_vertical_to_cache( + probabilities_zarr=probabilities_zarr, + probabilities_da=probabilities_da, + zarr_group=zarr_group, + probabilities=probabilities, + idx=idx, + tqdm_loop=tqdm_loop, + save_path=save_path, + chunk_shape=chunk_shape, + memory_threshold=memory_threshold, + ) ) if next_chunk is not None: @@ -2647,13 +2690,14 @@ def merge_multitask_vertical_chunkwise( def _save_multitask_vertical_to_cache( probabilities_zarr: list[zarr.Array] | list[None], probabilities_da: list[da.Array] | list[None], + zarr_group: zarr.Group | None, probabilities: np.ndarray, idx: int, tqdm_loop: tqdm, save_path: Path, chunk_shape: tuple, memory_threshold: int = 80, -) -> tuple[list[zarr.Array], list[da.Array] | None]: +) -> tuple[list[zarr.Array], list[da.Array] | None, zarr.Group | None]: """Helper function to save to zarr if vertical merge is out of memory.""" used_percent = 0 if probabilities_da[idx] is not None: @@ -2669,7 +2713,8 @@ def _save_multitask_vertical_to_cache( f"Saving intermediate results to disk." ) update_tqdm_desc(tqdm_loop=tqdm_loop, desc=msg) - zarr_group = zarr.open(str(save_path), mode="a") + if zarr_group is None: + zarr_group = zarr.open(str(save_path), mode="a") probabilities_zarr[idx] = zarr_group.create_array( name=f"probabilities/{idx}", shape=probabilities_da[idx].shape, @@ -2681,7 +2726,7 @@ def _save_multitask_vertical_to_cache( update_tqdm_desc(tqdm_loop=tqdm_loop, desc=desc) probabilities_da[idx] = None - return probabilities_zarr, probabilities_da + return probabilities_zarr, probabilities_da, zarr_group def _clear_zarr( @@ -3189,15 +3234,26 @@ def _update_tile_based_predictions_array( continue max_h, max_w = wsi_info_dict[idx]["predictions"].shape - x_end, y_end = min(x_end, max_w), min(y_end, max_h) + predictions = post_process_output_["predictions"] + tile_h, tile_w = predictions.shape[:2] + x_end_, y_end_ = ( + min(x_end, max_w, x_start + tile_w), + min( + y_end, + max_h, + y_start + tile_h, + ), + ) + if x_end_ <= x_start or y_end_ <= y_start: + continue new_predictions_ = post_process_output_["predictions"][ - 0 : y_end - y_start, 0 : x_end - x_start + 0 : y_end_ - y_start, 0 : x_end_ - x_start ] # Update instance values if post_process_output_["seg_type"] == "instance": previous_predictions_ = wsi_info_dict[idx]["predictions"][ - y_start:y_end, x_start:x_end + y_start:y_end_, x_start:x_end_ ] overlap = (new_predictions_ > 0) & (previous_predictions_ > 0) max_inst_value = 0 if max_inst_value is None else max_inst_value @@ -3219,7 +3275,7 @@ def _update_tile_based_predictions_array( else max_inst_value ) - wsi_info_dict[idx]["predictions"][y_start:y_end, x_start:x_end] = ( + wsi_info_dict[idx]["predictions"][y_start:y_end_, x_start:x_end_] = ( new_predictions_ ) @@ -3280,6 +3336,201 @@ def _build_tile_tasks( return tile_metadata +def _normalise_postproc_halo( + postproc_halo: int | tuple[int, int] | list[int] | np.ndarray | None, +) -> np.ndarray: + """Return post-processing halo in ``(x, y)`` order.""" + if postproc_halo is None: + return np.array([0, 0], dtype=np.int32) + + halo = np.asarray(postproc_halo, dtype=np.int32) + if halo.ndim == 0: + halo = np.repeat(halo, 2) + + if halo.shape != (2,): + msg = "`postproc_halo` must be an int or a length-2 sequence." + raise ValueError(msg) + + if np.any(halo < 0): + msg = "`postproc_halo` must be non-negative." + raise ValueError(msg) + + # Public tuple convention follows image shape order: (height, width). + return halo[::-1] + + +def _get_postproc_tile_read_bounds( + tile_bounds: tuple[int, int, int, int] | np.ndarray, + postproc_halo_xy: np.ndarray, + image_shape: tuple[int, int] | np.ndarray, +) -> tuple[int, int, int, int]: + """Expand tile bounds by halo and clip to the processed image shape.""" + tile_bounds = np.asarray(tile_bounds, dtype=np.int32) + image_shape = np.asarray(image_shape, dtype=np.int32) + read_tl = np.maximum(tile_bounds[:2] - postproc_halo_xy, 0) + read_br = np.minimum(tile_bounds[2:] + postproc_halo_xy, image_shape) + return tuple(np.concatenate([read_tl, read_br]).tolist()) + + +def _crop_halo_post_process_output( + post_process_output: tuple[dict, ...], + tile_bounds: tuple[int, int, int, int] | np.ndarray, + tile_read_bounds: tuple[int, int, int, int] | np.ndarray, +) -> tuple[dict, ...]: + """Crop halo-expanded post-processing output back to the tile core.""" + tile_bounds = np.asarray(tile_bounds, dtype=np.int32) + tile_read_bounds = np.asarray(tile_read_bounds, dtype=np.int32) + core_tl_in_read = tile_bounds[:2] - tile_read_bounds[:2] + core_br_in_read = tile_bounds[2:] - tile_read_bounds[:2] + + cropped_outputs = [] + for output in post_process_output: + output_ = output.copy() + + if "predictions" in output_: + output_["predictions"] = output_["predictions"][ + core_tl_in_read[1] : core_br_in_read[1], + core_tl_in_read[0] : core_br_in_read[0], + ] + + if "info_dict" in output_: + keep_mask = _get_halo_core_ownership_mask( + info_dict=output_["info_dict"], + core_tl_in_read=core_tl_in_read, + core_br_in_read=core_br_in_read, + ) + output_["info_dict"] = _filter_and_shift_halo_info_dict( + info_dict=output_["info_dict"], + keep_mask=keep_mask, + offset=-core_tl_in_read, + ) + + cropped_outputs.append(output_) + + return tuple(cropped_outputs) + + +def _get_halo_core_ownership_mask( + info_dict: dict, + core_tl_in_read: np.ndarray, + core_br_in_read: np.ndarray, +) -> np.ndarray: + """Return mask for objects owned by the unexpanded tile core.""" + instance_count = _get_info_dict_instance_count(info_dict) + if instance_count == 0: + return np.zeros(0, dtype=bool) + + points = _get_info_dict_ownership_points(info_dict) + if points is None: + return np.ones(instance_count, dtype=bool) + + return ( + (points[:, 0] >= core_tl_in_read[0]) + & (points[:, 0] < core_br_in_read[0]) + & (points[:, 1] >= core_tl_in_read[1]) + & (points[:, 1] < core_br_in_read[1]) + ) + + +def _get_info_dict_instance_count(info_dict: dict) -> int: + """Return the number of instances represented by an info dictionary.""" + for value in info_dict.values(): + if value is not None: + return len(value) + return 0 + + +def _get_info_dict_ownership_points(info_dict: dict) -> np.ndarray | None: + """Get representative points for ownership checks.""" + if "centroid" in info_dict: + return np.asarray(info_dict["centroid"], dtype=np.float32) + + if "box" in info_dict: + boxes = np.asarray(info_dict["box"], dtype=np.float32) + if boxes.size == 0: + return np.empty((0, 2), dtype=np.float32) + return (boxes[:, :2] + boxes[:, 2:]) / 2 + + if "contours" not in info_dict: + return None + + contours = np.asarray(info_dict["contours"]) + if contours.size == 0: + return np.empty((0, 2), dtype=np.float32) + + points = [] + pad_value = ( + np.iinfo(contours.dtype).min + if np.issubdtype(contours.dtype, np.integer) + else np.nan + ) + for contour in contours: + valid_mask = _get_valid_coordinate_rows(contour, pad_value) + valid_contour = contour[valid_mask] + if len(valid_contour) == 0: + points.append([np.nan, np.nan]) + continue + points.append( + ((valid_contour.min(axis=0) + valid_contour.max(axis=0)) / 2).tolist() + ) + return np.asarray(points, dtype=np.float32) + + +def _filter_and_shift_halo_info_dict( + info_dict: dict, + keep_mask: np.ndarray, + offset: np.ndarray, +) -> dict: + """Filter halo post-processing objects and shift coordinates to core space.""" + return { + key: _shift_halo_info_field( + key=key, + value=np.asarray(value)[keep_mask], + offset=offset, + ) + for key, value in info_dict.items() + } + + +def _shift_halo_info_field( + key: str, + value: np.ndarray, + offset: np.ndarray, +) -> np.ndarray: + """Shift geometric info fields from expanded-tile to core-tile coordinates.""" + if key == "box": + return value + np.array([offset[0], offset[1], offset[0], offset[1]]) + + if key == "centroid": + return value + offset + + if key != "contours": + return value + + contours = value.copy() + if contours.size == 0: + return contours + + pad_value = ( + np.iinfo(contours.dtype).min + if np.issubdtype(contours.dtype, np.integer) + else np.nan + ) + valid_mask = _get_valid_coordinate_rows(contours, pad_value) + contours[valid_mask] = (contours[valid_mask] + offset).astype(contours.dtype) + return contours + + +def _get_valid_coordinate_rows( + coordinates: np.ndarray, + pad_value: float, +) -> np.ndarray: + """Return rows that are not contour padding.""" + if np.isnan(pad_value): + return ~np.isnan(coordinates).all(axis=-1) + return ~(coordinates == pad_value).all(axis=-1) + + def _compute_info_dict_for_merge( inst_dict: dict, tile_mode: int, @@ -3863,6 +4114,7 @@ def _post_save_json_store( save_path: Path | None, **kwargs: Unpack[MultiTaskSegmentorRunParams], ) -> None: + """Clean temporary JSON-store data and report unsupported probability saves.""" for key in keys_to_compute: del processed_predictions[key] diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index c82f8ac71..6f4f4283f 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -40,6 +40,15 @@ def load_torch_model(model: nn.Module, weights: str | Path) -> nn.Module: # ! assume to be saved in single GPU mode # always load on to the CPU saved_state_dict = torch.load(weights, map_location="cpu") + saved_state_dict = ( + saved_state_dict["desc"] + if isinstance(saved_state_dict, dict) and "desc" in saved_state_dict + else saved_state_dict + ) + if all(k.split(".")[0] == "module" for k in saved_state_dict): + saved_state_dict = { + ".".join(k.split(".")[1:]): v for k, v in saved_state_dict.items() + } model.load_state_dict(saved_state_dict, strict=True) return model diff --git a/tiatoolbox/wsicore/wsireader.py b/tiatoolbox/wsicore/wsireader.py index 633ef4c75..d6f079de8 100644 --- a/tiatoolbox/wsicore/wsireader.py +++ b/tiatoolbox/wsicore/wsireader.py @@ -1786,6 +1786,8 @@ def tissue_mask( from tiatoolbox.tools import tissuemask # noqa: PLC0415 thumbnail = self.slide_thumbnail(resolution, units) + # set any black pixels to white to avoid them being included in the tissue mask + thumbnail[thumbnail.sum(axis=2) == 0] = 255 if method not in ["otsu", "morphological"]: msg = f"Invalid tissue masking method: {method}." raise ValueError(msg)