diff --git a/.assets/3dcellulus.webp b/.assets/3dcellulus.webp new file mode 100644 index 0000000..2454450 Binary files /dev/null and b/.assets/3dcellulus.webp differ diff --git a/.assets/autospem.webp b/.assets/autospem.webp new file mode 100644 index 0000000..0d96a68 Binary files /dev/null and b/.assets/autospem.webp differ diff --git a/README.md b/README.md index 143fec1..64a57f4 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,49 @@ -# cellulus -Self-supervised instance segmentation of cells + +# Unsupervised Learning of Object-Centric Embeddings for Cell Instance Segmentation in Microscopy Images + +*Algorithm for unsupervised cell instance segmentation.* We present a self-supervised learning method for object-centric embeddings (OCEs) which embed image patches such that the spatial offsets between patches cropped from the same object are preserved. Those learnt embeddings can be used to delineate individual objects and thus obtain instance segmentations. The method relies on the assumptions (commonly found in microscopy images) that objects have a similar appearance and are randomly distributed in the image. If ground-truth annotations are available, this method serves as an excellent starting point for supervised training, reducing the required amount of ground-truth needed + +![](.assets/autospem.webp) + +## Requirements and Setup + +Install the required packages with conda +``` +conda create --name autospem --file environment.yml +``` + +## Train Spatial Instance Embedding Networks + + +``` +python colocseg/train_ssl.py --shape 252 252 --in_channels 2 --out_channels 2 --dspath --initial_lr 4e-05 --output_shape 236 236 --positive_radius 10 --regularization 1e-05 --check_val_every_n_epoch 10 --limit_val_batches 256 --max_epochs 50 --temperature 10 --lr_milestones 20 30 --batch_size 8 --loader_workers 8 --gpu 1 +``` + +## Infer Mean and Std of Spatial Embeddings + +``` +python colocseg/infer_spatial_embeddings.py /model.torch output.zarr spatial_embedding /tissuenet_v1.0_test.npz 102 raw 2 32 transpose +``` + +## Infer Segmentation from Spatial Embedding + +``` +python colocseg/infer_pseudo_gt_from_mean_std.py output.zarr /tissuenet_v1.0_test.npz spatial_embedding meanshift_segmentation 0 0.21 +``` +## Postprocess Embeddings (Shrinking Instances by Fixed Distance) + +``` +python scripts/postprocess.py output.zarr meanshift_segmentation +``` + + + +## External Datasets + +Models were trained on cell segmentation datasets that are part of the [tissuenet dataset](https://datasets.deepcell.org/) and the [cell tracking challenge datasets](http://celltrackingchallenge.net/2d-datasets/) + +## 3D Segmentation + +![](.assets/3dcellulus.webp) +> Fully unsupervised 3D segmentation with no prior training + diff --git a/colocseg/__init__.py b/colocseg/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/colocseg/datamodules.py b/colocseg/datamodules.py new file mode 100644 index 0000000..cd40d45 --- /dev/null +++ b/colocseg/datamodules.py @@ -0,0 +1,331 @@ +import argparse +import pytorch_lightning as pl +import torch +import imgaug as ia +import os +import json +import random +import zarr +from torch.utils.data import DataLoader, ConcatDataset, Subset +from colocseg.datasets import (CoordinateDataset, TissueNetDataset, ZarrSegmentationDataset, + EvenlyMixedDataset, ZarrRegressionDataset, ZarrImageDataset) +import numpy as np + + +class AnchorDataModule(pl.LightningDataModule): + + def __init__(self, batch_size, dspath, + shape=(256, 256), output_shape=(240, 240), + loader_workers=10, positive_radius=32): + + super().__init__() + self.batch_size = batch_size + self.dspath = dspath + self.shape = tuple(int(_) for _ in shape) + self.output_shape = tuple(int(_) for _ in output_shape) + self.loader_workers = loader_workers + output_shape + self.positive_radius = positive_radius + + def setup_datasets(self): + raise NotImplementedError() + + def setup(self, stage=None): + + img_ds_train, img_ds_val = self.setup_datasets() + + self.ds_train = CoordinateDataset( + img_ds_train, + self.output_shape, + self.positive_radius, + density=0.1, + return_segmentation=False) + + self.ds_val = CoordinateDataset( + img_ds_val, + self.output_shape, + self.positive_radius, + density=0.1, + return_segmentation=True) + + return (img_ds_train, img_ds_val), (self.ds_train, self.ds_val) + + def train_dataloader(self): + return DataLoader(self.ds_train, + shuffle=True, + batch_size=self.batch_size, + num_workers=self.loader_workers, + drop_last=False) + + def val_dataloader(self): + return DataLoader(self.ds_val, + batch_size=1, + shuffle=False, + num_workers=2, + drop_last=False) + + def test_dataloader(self): + return None + + @staticmethod + def add_model_specific_args(parent_parser): + parser = argparse.ArgumentParser( + parents=[parent_parser], add_help=False) + try: + parser.add_argument('--batch_size', type=int, default=8) + except argparse.ArgumentError: + pass + parser.add_argument('--loader_workers', type=int, default=8) + parser.add_argument('--dspath', type=str) + parser.add_argument('--shape', nargs='*', default=(256, 256)) + parser.add_argument('--output_shape', nargs='*', default=(256, 256)) + parser.add_argument('--positive_radius', type=int, default=64) + + return parser + + +class TissueNetDataModule(AnchorDataModule): + + def setup_datasets(self): + + train_ds = TissueNetDataset(os.path.join(self.dspath, "tissuenet_v1.0_train.npz"), + crop_to=self.shape, + augment=True) + + val_ds = TissueNetDataset(os.path.join(self.dspath, "tissuenet_v1.0_val.npz"), + crop_to=(252, 252), + augment=False) + + return ConcatDataset([train_ds, val_ds]), val_ds + + +class CTCDataModule(AnchorDataModule): + + def setup_datasets(self): + + train_ds = ZarrImageDataset(self.dspath, + "train", + crop_to=self.shape, + augment=True) + + val_ds = ZarrImageDataset(self.dspath, + "train", + crop_to=(252, 252), + augment=False) + + return train_ds, train_ds + +class PartiallySupervisedDataModule(pl.LightningDataModule): + + def __init__(self, batch_size, tissuenet_root, + pseudo_gt_root, + raw_key, + gt_key, + pseudo_gt_key, + pseudo_gt_val_key=None, + pseudo_gt_score_key=None, + target_transform="stardist", + target_transform_aux="stardist", + corrected_instances_file=None, + tissue_type=None, + target_type="cell", + shape=(256, 256), + output_shape=(240, 240), + limit=None, + datasetseed=42, + loader_workers=10): + + super().__init__() + self.batch_size = batch_size + self.tissuenet_root = tissuenet_root + self.pseudo_gt_root = pseudo_gt_root + self.raw_key = raw_key + self.gt_key = gt_key + self.pseudo_gt_key = pseudo_gt_key + + if pseudo_gt_val_key is None: + self.pseudo_gt_val_key = pseudo_gt_key + else: + self.pseudo_gt_val_key = pseudo_gt_val_key + self.pseudo_gt_score_key = pseudo_gt_score_key + + self.target_transform = target_transform + self.target_transform_aux = target_transform_aux + self.corrected_instances_file = corrected_instances_file + self.shape = tuple(int(_) for _ in shape) + self.output_shape = tuple(int(_) for _ in output_shape) + self.loader_workers = loader_workers + self.limit = limit + self.seed = datasetseed + self.tissue_type = tissue_type + self.target_type = target_type + + def setup_datasets(self): + + aux_is_segmentation = self.target_transform_aux is not None + + if aux_is_segmentation: + val_ds = ZarrSegmentationDataset(self.pseudo_gt_root, + {"raw": "val/raw", + "gt": f"val/{self.pseudo_gt_val_key}", + "tissue_list": "val/tissue_list"}, + tissue_type=self.tissue_type, + target_type="cell", + crop_to=(252, 252), + augment=False, + smooth_boundaries=True) + else: + val_ds = ZarrRegressionDataset(self.pseudo_gt_root, + {"raw": "val/raw", + "target": f"val/{self.pseudo_gt_val_key}", + "tissue_list": "val/tissue_list"}, + tissue_type=self.tissue_type, + target_type="cell", + crop_to=(252, 252), + augment=False) + + corrected_instances = None + if self.corrected_instances_file is not None: + if self.corrected_instances_file == 'all': + + train_ds = ZarrSegmentationDataset(self.pseudo_gt_root, + {"raw": f"train/{self.raw_key}", + "gt": f"train/{self.gt_key}", + "correction": f"train/{self.gt_key}", + "tissue_list": "train/tissue_list"}, + tissue_type=self.tissue_type, + target_type=self.target_type, + target_transform=self.target_transform_aux, + crop_to=self.shape, + augment=True, + smooth_boundaries=True) + + return train_ds, train_ds, val_ds + else: + with open(self.corrected_instances_file, 'r') as fp: + def toint(x): + return {int(k): v for k, v in x} + corrected_instances = json.load(fp, object_pairs_hook=toint) + + train_ds = ZarrSegmentationDataset(self.pseudo_gt_root, + {"raw": f"train/{self.raw_key}", + "gt": f"train/{self.gt_key}", + "correction": f"train/{self.gt_key}", + "tissue_list": "train/tissue_list"}, + tissue_type=self.tissue_type, + target_type=self.target_type, + target_transform=self.target_transform_aux, + corrected_instances=corrected_instances, + limit_to_correction=True, + crop_to=self.shape, + augment=True, + smooth_boundaries=True) + + if aux_is_segmentation: + train_pseudo_ds = ZarrSegmentationDataset(self.pseudo_gt_root, + {"raw": f"train/{self.raw_key}", + "gt": f"train/{self.pseudo_gt_key}", + "correction": f"train/{self.gt_key}", + "tissue_list": "train/tissue_list"}, + tissue_type=self.tissue_type, + target_type="cell", + target_transform=self.target_transform_aux, + corrected_instances=corrected_instances, + limit_to_correction=False, + crop_to=self.shape, + augment=True, + smooth_boundaries=True) + + else: + train_pseudo_ds = ZarrRegressionDataset(self.pseudo_gt_root, + {"raw": f"train/{self.raw_key}", + "target": f"train/{self.pseudo_gt_key}", + "tissue_list": "train/tissue_list"}, + tissue_type=self.tissue_type, + target_type="cell", + crop_to=self.shape, + augment=True) + + return train_ds, train_pseudo_ds, val_ds + + def setup(self, stage=None): + + ds_train, ds_pseudo, ds_val = self.setup_datasets() + self.ds_train = ds_train + self.ds_pseudo = ds_pseudo + self.ds_val = ds_val + + def train_dataloader(self): + + if self.limit is None: + mixed_train_ds = EvenlyMixedDataset([self.ds_train, self.ds_pseudo], + [self.batch_size // 2, self.batch_size // 2]) + else: + assert(len(self.ds_train) == len(self.ds_pseudo)) + assert(len(self.ds_train) >= self.limit) + + if self.pseudo_gt_score_key is None: + supervised_indices = np.random.RandomState(seed=self.seed).permutation(len(self.ds_train))[:self.limit] + remaining_indices = np.random.RandomState(seed=self.seed).permutation(len(self.ds_train))[self.limit:] + else: + z = zarr.open(self.pseudo_gt_root, "r") + scores = z[self.pseudo_gt_score_key][:] + assert(z[f"train/{self.pseudo_gt_key}"].shape[0] == scores.shape[0]) + sorted_indices = np.argsort(scores) + supervised_indices = sorted_indices[:self.limit] + remaining_indices = sorted_indices[self.limit:] + + supervised_limited_train_ds = Subset(self.ds_train, supervised_indices) + remainin_pseudo_ds = Subset(self.ds_pseudo, remaining_indices) + supervised_plus_pseudo = ConcatDataset([supervised_limited_train_ds, remainin_pseudo_ds]) + mixed_train_ds = EvenlyMixedDataset([supervised_limited_train_ds, supervised_plus_pseudo], + [self.batch_size // 2, self.batch_size // 2]) + + def seed_worker(worker_id): + worker_seed = torch.initial_seed() % 2**32 + ia.seed(worker_seed) + np.random.seed(worker_seed) + random.seed(worker_seed) + + return DataLoader(mixed_train_ds, + batch_size=None, + worker_init_fn=seed_worker, + num_workers=self.loader_workers) + + def val_dataloader(self): + return DataLoader(self.ds_val, + batch_size=1, + shuffle=False, + num_workers=2, + drop_last=False) + + def test_dataloader(self): + return None + + @ staticmethod + def add_model_specific_args(parent_parser): + parser = argparse.ArgumentParser( + parents=[parent_parser], add_help=False) + try: + parser.add_argument('--batch_size', type=int, default=8) + except argparse.ArgumentError: + pass + parser.add_argument('--loader_workers', type=int, default=8) + parser.add_argument('--tissuenet_root', type=str) + parser.add_argument('--target_transform', type=str, default="stardist") + parser.add_argument('--target_transform_aux', type=str, default=None) + parser.add_argument('--corrected_instances_file', type=str, default=None) + parser.add_argument('--pseudo_gt_root', type=str) + parser.add_argument('--gt_key', type=str, default="gt") + parser.add_argument('--raw_key', type=str, default="raw") + parser.add_argument('--pseudo_gt_key', type=str) + parser.add_argument('--pseudo_gt_val_key', type=str, default=None) + parser.add_argument('--target_type', type=str, default="cell") + parser.add_argument('--tissue_type', type=str, default=None) + parser.add_argument('--shape', nargs='*', default=(256, 256)) + parser.add_argument('--output_shape', nargs='*', default=(256, 256)) + parser.add_argument('--limit', type=int, default=None) + parser.add_argument('--datasetseed', type=int, default=42) + parser.add_argument('--pseudo_gt_score_key', type=str, default=None) + + return parser diff --git a/colocseg/datasets.py b/colocseg/datasets.py new file mode 100644 index 0000000..ba27449 --- /dev/null +++ b/colocseg/datasets.py @@ -0,0 +1,496 @@ +from colocseg.utils import get_augmentation_transform, smooth_boundary_fn, sizefilter, CropAndSkipIgnore +from colocseg.transforms import AffinityTf, StardistTf, ThreeclassTf, CellposeTf +from torch.utils.data.dataset import Dataset +import numpy as np +import torch +import random +from scipy.ndimage.morphology import distance_transform_edt +from imgaug import augmenters as iaa +import zarr + + +class CoordinateDataset(Dataset): + + def __init__(self, + dataset, + output_shape, + positive_radius, + density=0.2, + return_segmentation=True): + """ + max_imbalance_dist: int + a patch on the edge of the image has an imabalanced + set of neighbours. We compute the distance to the averge + neigbour (zero for patches in the center) + If the average vector is longer than + max_imbalance_dist the connection is removed + """ + + self.root_dataset = dataset + self.return_segmentation = return_segmentation + self.density = density + self.output_shape = tuple(int(_) for _ in output_shape) + self.unbiased_shape = tuple(int(_ - (2 * positive_radius)) for _ in output_shape) + self.positive_radius = int(positive_radius) + + def sample_offsets_within_radius(self, radius, n_offsets): + + # sample 2 times more samples in a square + # so that we have room to filter out the ones outside the circle + offs_x = np.random.randint(-radius, radius + 1, + size=2 * n_offsets) + offs_y = np.random.randint(-radius, radius + 1, + size=2 * n_offsets) + + offs_coord = np.stack((offs_x, offs_y), axis=1) + in_circle = (offs_coord**2).sum(axis=1) < radius ** 2 + offs_coord = offs_coord[in_circle] + not_zero = np.absolute(offs_coord).sum(axis=1) > 0 + offs_coord = offs_coord[not_zero] + + if len(offs_coord) < n_offsets: + # if the filter removed more than half of all samples + # repeat + return self.sample_offsets_within_radius(radius, n_offsets) + + return offs_coord[:n_offsets] + + def sample_coordinates(self): + """samples pairs of corrdinates within outputshape. + Returns: + anchor_samples: uniformly sampled coordinates where all points within + the positive radius are still within the outputshape + anchor_samples.shape = (p, 2) + reference_samples: random coordinates in full outputshape space. + all_samples.shape = (p, 2) + """ + + n_anchor = self.get_num_anchors() + n_reference = self.get_num_references() + + anchor_coord_x = np.random.randint(self.positive_radius, + self.output_shape[0] - self.positive_radius + 1, + size=n_anchor) + anchor_coord_y = np.random.randint(self.positive_radius, + self.output_shape[1] - self.positive_radius + 1, + size=n_anchor) + anchor_coord = np.stack((anchor_coord_x, anchor_coord_y), axis=1) + + anchor_samples = np.repeat(anchor_coord, n_reference, axis=0) + + offset_in_pos_radius = self.sample_offsets_within_radius(self.positive_radius, + len(anchor_samples)) + refernce_samples = anchor_samples + offset_in_pos_radius + + return anchor_samples, refernce_samples + + def get_num_anchors(self): + return int(self.density * self.unbiased_shape[0] * self.unbiased_shape[1]) + + def get_num_references(self): + return int(self.density * self.positive_radius**2 * np.pi) + + def get_num_samples(self): + return self.get_num_anchors() * self.get_num_references() + + def unpack(self, sample): + if isinstance(sample, tuple): + if len(sample) == 2: + x, y = sample + else: + x = sample[0] + y = 0. + else: + x = sample + y = 0. + + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + + return x, y + + def __len__(self): + return len(self.root_dataset) + + def __getitem__(self, index): + + x, y = self.unpack(self.root_dataset[index]) + + anchor_coordinates, refernce_coordinates = self.sample_coordinates() + if self.return_segmentation: + return x, anchor_coordinates, refernce_coordinates, y + else: + return x, anchor_coordinates, refernce_coordinates + + +class TissueNetDataset(Dataset): + + def __init__(self, + data_file, + tissue_type=None, + target_type="cell", + target_transform=None, + crop_to=None, + augment=True, + smooth_boundaries=False, + min_size=25): + + super().__init__() + + self.target_transform = target_transform + if target_transform == 'threeclass': + self.target_tf = ThreeclassTf(inner_distance=2) + elif target_transform == 'affinity': + self.target_tf = AffinityTf() + elif target_transform == 'stardist': + self.target_tf = StardistTf() + elif target_transform == 'cellpose': + self.target_tf = CellposeTf() + + self.tissue_type = tissue_type + self.target_type = target_type + self.data_file = data_file + self.load_data() + self.augment = augment + self.apply_smooth_boundaries = smooth_boundaries + self.min_size = min_size + self.valid_crop = 46 + + if crop_to is not None: + self.crop_fn = iaa.CropToFixedSize(width=crop_to[0], height=crop_to[1]) + else: + self.crop_fn = None + + self.batch_augmentation_fn = get_augmentation_transform() + + def __len__(self): + if not hasattr(self, "_length"): + self._length = self.raw_data.shape[0] + return self._length + + def load_data(self): + with np.load(self.data_file) as data: + if self.tissue_type is None: + self.raw_data = data['X'] + self.gt_data = data['y'] + else: + mask = data["tissue_list"] == self.tissue_type + self.raw_data = data['X'][mask] + self.gt_data = data['y'][mask] + + def augment_batch(self, raw, gtseg, idx): + # prepare images for iaa transform + + raw, gtseg = self.pre_process(raw, gtseg, idx) + gtseg = gtseg[None] # CHW -> HWC + + if self.augment: + # make sure that segmentation does not contain negative values + min_id = gtseg.min() + gtseg = gtseg - min_id + raw, gtseg = self.batch_augmentation_fn(image=raw, + segmentation_maps=gtseg) + gtseg = gtseg + min_id + + if self.crop_fn is not None: + # make sure that segmentation does not contain negative values + min_id = gtseg.min() + gtseg = gtseg - min_id + raw, gtseg = self.crop_fn(image=raw, + segmentation_maps=gtseg) + gtseg = gtseg + min_id + + raw = np.transpose(raw, [2, 0, 1]) # HWC -> CHW + + if gtseg.shape[-1] == 1: + gtseg = gtseg[0, ..., 0] + else: + gtseg = np.transpose(gtseg[0], [2, 0, 1]) + + return raw, gtseg + + def smooth_boundaries(self, segmentation): + segmentation = sizefilter(segmentation, self.min_size) + segmentation = smooth_boundary_fn(segmentation) + segmentation = sizefilter(segmentation, self.min_size) + return segmentation + + def get_gt(self, idx): + if self.target_type == "cell": + return self.gt_data[idx, ..., :1].astype(np.int32) + elif self.target_type == "nucleus": + return self.gt_data[idx, ..., 1:2].astype(np.int32) + else: + raise NotImplementedError(f"unknown target_type: {self.target_type}") + + def pre_process(self, raw, gt, idx): + if self.valid_crop == 0: + return raw, gt + p2d = ((self.valid_crop * 2, self.valid_crop * 2), (self.valid_crop * 2, self.valid_crop * 2), (0, 0)) + raw_padded = np.pad(raw, p2d, mode='constant') + gt_padded = np.pad(gt, p2d, mode='constant', constant_values=-1) + return raw_padded, gt_padded + + def __getitem__(self, idx): + + raw = self.raw_data[idx] + gt_segmentation = self.get_gt(idx) + if self.apply_smooth_boundaries: + gt_segmentation = self.smooth_boundaries(gt_segmentation) + raw, gt_segmentation = self.augment_batch(raw, gt_segmentation, idx) + + if self.target_transform is None: + return raw, gt_segmentation + else: + tc = self.target_tf(gt_segmentation) + return raw, tc, gt_segmentation + + +class EvenlyMixedDataset(Dataset): + def __init__(self, datasets, batch_sizes, ): + self.datasets = datasets + self.batch_sizes = batch_sizes + assert len(self.datasets) == len(self.batch_sizes) + + def __getitem__(self, idx): + batch = [] + + for ds, bs in zip(self.datasets, self.batch_sizes): + samples = [] + for _ in range(bs): + idx = random.randrange(0, len(ds)) + samples.append(ds[idx]) + + batch.append(tuple(np.stack((d[i] for d in samples), axis=0) for i in range(len(samples[0])))) + + return batch + + def __len__(self): + return max([len(_) for _ in self.datasets]) // sum(self.batch_sizes) + + +class ZarrSegmentationDataset(TissueNetDataset): + + def __init__(self, + data_file, + keys, + tissue_type=None, + target_type="cell", + target_transform=None, + corrected_instances=None, + limit_to_correction=False, + crop_to=None, + augment=True, + smooth_boundaries=False): + + self.keys = keys + self.tissue_type = tissue_type + + self.corrected_instances = corrected_instances + self.limit_to_correction = limit_to_correction and corrected_instances is not None + + super().__init__(data_file, + tissue_type=tissue_type, + target_type=target_type, + target_transform=target_transform, + crop_to=crop_to, + augment=augment, + smooth_boundaries=smooth_boundaries) + + if limit_to_correction: + self.crop_fn = CropAndSkipIgnore(self.crop_fn) + + + def pre_process(self, raw, gt, idx): + + if self.corrected_instances is None: + return raw, gt + + if self.limit_to_correction: + gt[:] = -1 + + if idx in self.corrected_instances: + correct_segmentation = self.correction_data[idx] + + if self.target_type == "cell": + correct_segmentation = correct_segmentation[..., :1].astype(np.int32) + elif self.target_type == "nucleus": + correct_segmentation = correct_segmentation[..., 1:2].astype(np.int32) + else: + raise NotImplementedError(f"unknown target_type: {self.target_type}") + + correction_mask = np.in1d(correct_segmentation.ravel(), self.corrected_instances[idx]).reshape(correct_segmentation.shape) + + if not self.limit_to_correction: + touching_instances = np.unique(gt[correction_mask]) + touching_instances = touching_instances[touching_instances > 0] + ignore_mask = np.in1d(gt.ravel(), touching_instances).reshape(gt.shape) + gt[ignore_mask] = -1 + # include gt background close to the corrected instances + close_to_correction = distance_transform_edt(correction_mask == 0) < 30 + close_background = np.logical_and(correct_segmentation == 0, close_to_correction) + gt[close_background] = 0 + gt[correction_mask] = correct_segmentation[correction_mask] + gt.max() + 1 + + raw, gt = super().pre_process(raw, gt, idx) + + return raw, gt + + def load_data(self): + zin = zarr.open(self.data_file, "r") + + selection = None + if self.tissue_type is not None: + mask = zin[self.keys["tissue_list"]][:] == self.tissue_type + selection = np.where(mask)[0] + + if self.limit_to_correction: + if selection is None: + print("selecting") + selection = np.array(list(self.corrected_instances.keys())) + else: + print("intersecting") + to_correct = np.array(list(self.corrected_instances.keys())) + selection = np.intersect1d(selection, to_correct) + + remapped_instances = {} + for i, s in enumerate(selection): + remapped_instances[i] = self.corrected_instances[s] + + self.corrected_instances = remapped_instances + + if selection is None: + self.raw_data = zin[self.keys["raw"]] + self.gt_data = zin[self.keys["gt"]] + if "correction" in self.keys: + self.correction_data = zin[self.keys["correction"]] + + else: + selection_mask = np.in1d(np.arange(len(zin[self.keys["raw"]])), selection) + self.raw_data = zin[self.keys["raw"]].get_orthogonal_selection(selection_mask) + self.gt_data = zin[self.keys["gt"]].get_orthogonal_selection(selection_mask) + if "correction" in self.keys: + self.correction_data = zin[self.keys["correction"]].get_orthogonal_selection(selection_mask) + +class ZarrRegressionDataset(Dataset): + + def __init__(self, + data_file, + keys, + tissue_type=None, + target_type="cell", + crop_to=None, + augment=True): + + super().__init__() + + self.keys = keys + self.tissue_type = tissue_type + self.target_type = target_type + self.data_file = data_file + self.load_data() + self.augment = augment + + if crop_to is not None: + self.crop_fn = iaa.CropToFixedSize(width=crop_to[0], height=crop_to[1]) + else: + self.crop_fn = None + + self.batch_augmentation_fn = get_augmentation_transform(simple=True) + + def load_data(self): + zin = zarr.open(self.data_file, "r") + if self.tissue_type is None: + self.raw_data = zin[self.keys["raw"]] + self.target_data = zin[self.keys["target"]] + else: + sel = np.where(zin[self.keys["tissue_list"]][:] == self.tissue_type)[0] + self.raw_data = zin[self.keys["raw"]].get_orthogonal_selection(sel) + self.target_data = zin[self.keys["target"]].get_orthogonal_selection(sel) + + def get_target(self, idx): + # expect TissueNet Dataset format N, W, H, C + return np.transpose(self.target_data[idx], [1, 2, 0]) + + def __len__(self): + if not hasattr(self, "_length"): + self._length = self.raw_data.shape[0] + return self._length + + def augment_batch(self, raw, target): + # prepare images for iaa transform + split_c = raw.shape[-1] + if target.ndim == 2: + target = target[..., None] + raw_target = np.concatenate([raw, target], axis=-1) + if self.augment: + raw_target = self.batch_augmentation_fn(image=raw_target) + if self.crop_fn is not None: + raw_target = self.crop_fn(image=raw_target) + raw = raw_target[..., :split_c] + target = raw_target[..., split_c:] + raw = np.transpose(raw, [2, 0, 1]) # HWC -> CHW + target = np.transpose(target, [2, 0, 1]) # HWC -> CHW + return raw, target + + def __getitem__(self, idx): + + raw = self.raw_data[idx] + tc = self.get_target(idx) + raw, tc = self.augment_batch(raw, tc) + + return raw, tc, None + + +class ZarrImageDataset(Dataset): + + def __init__(self, + data_file, + key, + crop_to=None, + augment=True): + + super().__init__() + + self.data_file = data_file + self.key = key + self.load_data() + self.augment = augment + if crop_to is not None: + self.crop_fn = iaa.CropToFixedSize(width=crop_to[0], height=crop_to[1]) + else: + self.crop_fn = None + + self.batch_augmentation_fn = get_augmentation_transform(medium=True) + + def load_data(self): + zin = zarr.open(self.data_file, "r") + img_keys = [f"{self.key}/{k}" for k in zin[self.key] if k.startswith("raw")] + print("reading ", img_keys) + self.raw_data = [zin[rk] for rk in img_keys] + + def __len__(self): + if not hasattr(self, "_length"): + self._length = sum([len(_) for _ in self.raw_data]) + return self._length + + def augment_batch(self, raw): + # prepare images for iaa transform + raw = np.transpose(raw, [1, 2, 0]) # HWC -> CHW + if self.augment: + raw = self.batch_augmentation_fn(image=raw) + if self.crop_fn is not None: + raw = self.crop_fn(image=raw) + raw = np.transpose(raw, [2, 0, 1]) # HWC -> CHW + return raw.copy() + + def __getitem__(self, idx): + f=0 + while idx >= len(self.raw_data[f]): + idx = idx - len(self.raw_data[f]) + f += 1 + + raw = self.raw_data[f][idx] + raw = self.augment_batch(raw) + + return raw, np.zeros(raw.shape[1:]).astype(np.int32) diff --git a/colocseg/evaluation.py b/colocseg/evaluation.py new file mode 100644 index 0000000..fa7e511 --- /dev/null +++ b/colocseg/evaluation.py @@ -0,0 +1,407 @@ +import argparse +import copy +import inspect +import os +from numcodecs import gzip + +import numpy as np +import skimage +import torch +from torch.nn.functional import sigmoid +import zarr +from PIL import Image +from pytorch_lightning.callbacks import Callback +from skimage.io import imsave +from sklearn.cluster import MeanShift +from deepcell_toolbox import metrics +import json +from stardist.matching import matching_dataset +from colocseg.inference import infer, affinity_segmentation, stardist_segmentation, cellpose_segmentation +from colocseg.utils import cluster_embeddings, label2color, zarr_append +from colocseg.visualizations import vis_anchor_embedding +from colocseg.metrics import segmentation_metric + + +def n(a): + out = a - a.min() + out /= out.max() + 1e-8 + return out + + +class AnchorSegmentationValidation(Callback): + + def __init__(self, run_segmentation=False, device='cpu'): + self.run_segmentation = run_segmentation + self.device = device + self.metrics = metrics.Metrics('colocseg') + super().__init__() + + def on_validation_epoch_start(self, trainer, pl_module): + """Called when the validation loop begins.""" + self.seg_scores = {} + + def predict_embedding(self, x, pl_module): + + with torch.no_grad(): + embedding_spatial = pl_module.forward(x.to(pl_module.device)) + + return embedding_spatial + + def create_eval_dir(self, pl_module): + eval_directory = os.path.abspath(os.path.join(pl_module.logger.log_dir, + os.pardir, + os.pardir, + "evaluation", + f"{pl_module.global_step:08d}")) + + os.makedirs(eval_directory, exist_ok=True) + return eval_directory + + def visualize_embeddings(self, embedding, x, filename): + if embedding is None: + return + + for b in range(len(embedding)): + e = embedding[b].cpu().numpy() + for c in range(0, embedding.shape[1], 2): + raw = x[b, 0].cpu().numpy() + if raw.shape[-1] != e.shape[-1]: + pd0 = (raw.shape[-1] - e.shape[-1]) // 2 + pd1 = (raw.shape[-2] - e.shape[-2]) // 2 + raw = raw[..., pd1:-pd1, pd0:-pd0] + imsave(f"{filename}_{b}_{c}.jpg", + np.stack((n(e[c]), n(raw), n(e[c + 1])), axis=-1)) + + def visualize_segmentation(self, seg, x, filename): + colseg = label2color(seg).transpose(1, 2, 0) + img = np.repeat(x[0, ..., None].cpu().numpy(), 3, axis=-1) + blend = (colseg[..., :3] / 2) + img + imsave(filename, blend) + + def visualize_embedding_vectors(self, embedding_relative, x, filename, downsample_factor=8): + + for b in range(len(embedding_relative)): + e = embedding_relative[b].cpu().numpy() + + cx = np.arange(e.shape[-2], dtype=np.float32) + cy = np.arange(e.shape[-1], dtype=np.float32) + coords = np.meshgrid(cx, cy, copy=True) + coords = np.stack(coords, axis=-1) + e_transposed = np.transpose(e, (1, 2, 0)) + + dsf = downsample_factor + spatial_dims = coords.shape[-1] + patch_coords = coords[dsf // 2::dsf, dsf // 2::dsf].reshape(-1, spatial_dims) + patch_embedding = e_transposed[dsf // 2::dsf, dsf // 2::dsf, :spatial_dims].reshape(-1, spatial_dims) + + vis_anchor_embedding(patch_embedding, + patch_coords, + n(x[b].cpu().numpy()), + grad=None, + output_file=f"{filename}_{b}.jpg") + + def write_embedding_to_file(self, eval_data_file, embedding, embedding_relative, x, y): + # write embedding and raw data to file + z_array = zarr.open(eval_data_file, mode="w") + for b in range(len(embedding_relative)): + z_array.create_dataset(f"{b}/embedding", data=embedding_relative[b].cpu().numpy(), compression='gzip') + z_array.create_dataset(f"{b}/embedding_abs", data=embedding.cpu().numpy()[b], compression='gzip') + z_array.create_dataset(f"{b}/gt_segmentation", data=y.cpu().numpy()[b, None], compression='gzip') + z_array.create_dataset(f"{b}/raw", data=x[b].cpu().numpy(), compression='gzip') + threshold = skimage.filters.threshold_li(image=x[b].cpu().numpy()) + z_array.create_dataset(f"{b}/threshold_li", data=255 * + (x[b].cpu().numpy() > threshold).astype(dtype=np.uint8), compression='gzip') + + return z_array + + def meanshift_segmentation(self, embedding, bandwidths): + # Compute Meanshift Segmentation + ms_segmentaiton = {} + + for i, bandwidth in enumerate(bandwidths): + ms_segmentaiton[bandwidth] = [] + ms_seg = segment_with_meanshift(embedding, bandwidth) + + for b, seg in enumerate(ms_seg): + ms_segmentaiton[bandwidth].append(seg) + + return ms_segmentaiton + + def visualizalize_segmentation_dict(self, segmentation_dict, x, filename): + for k in segmentation_dict: + for b, seg in enumerate(segmentation_dict[k]): + print(f'writing {filename}_{b}_{k}.jpg') + self.visualize_segmentation(seg, x[b], f'{filename}_{b}_{k}.jpg') + + def visualize_all(self, eval_directory, x, embedding_spatial, batch_idx, pl_module): + vis_pointer_filename = f"{eval_directory}/pointer_embedding_{batch_idx}_{pl_module.local_rank}" + vis_inst_embedding_filename = f"{eval_directory}/instance_embedding_{batch_idx}_{pl_module.local_rank}" + vis_relembedding_filename = f"{eval_directory}/spatial_embedding_{batch_idx}_{pl_module.local_rank}" + self.visualize_embedding_vectors(embedding_spatial, x, vis_pointer_filename) + self.visualize_embeddings(embedding_spatial, x, vis_relembedding_filename) + + def full_evaluation(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + + x, anchor_coordinates, refernce_coordinates, y = batch + + eval_directory = self.create_eval_dir(pl_module) + + embedding_spatial = self.predict_embedding(x, + pl_module) + if batch_idx < 32: + self.visualize_all(eval_directory, x, embedding_spatial, batch_idx, pl_module) + + cut = (y.shape[-1] - embedding_spatial.shape[-1]) // 2 + gt = y.cpu()[..., cut:-cut, cut:-cut].numpy() + self.gt_segs.append(gt) + + eval_data_file = f"{eval_directory}/embedding_{pl_module.local_rank}.zarr" + z_array = zarr.open(eval_data_file, mode="a") + z_array.create_dataset(f"embedding/{batch_idx}", data=embedding_spatial.cpu().numpy(), compression='gzip') + z_array.create_dataset(f"raw/{batch_idx}", data=x[..., cut:-cut, cut:-cut].cpu().numpy(), compression='gzip') + + embedding_spatial = embedding_spatial.cpu() + embedding_spatial[:, 1] += torch.arange(embedding_spatial.shape[2])[None, :, None] + embedding_spatial[:, 0] += torch.arange(embedding_spatial.shape[3])[None, None, :] + + if self.run_segmentation: + ms_bandwidths = (8,) + ms_segmentations = self.meanshift_segmentation(embedding_spatial, ms_bandwidths) + + for k in ms_segmentations: + if f"meanshift_{k}" not in self.pred_segs: + self.pred_segs[f"meanshift_{k}"] = [] + self.pred_segs[f"meanshift_{k}"].append(np.stack(ms_segmentations[k])) + + for eps in [0.1, 0.25, 0.5, 1., 2, 4, 8]: + clusters = cluster_embeddings(embedding_spatial, eps=1.) + # add 1 to set the label -1 indicating that no cluster was found to label=0 + clusters += 1 + if f"dbscan_{eps}" not in self.pred_segs: + self.pred_segs[f"dbscan_{eps}"] = [] + self.pred_segs[f"dbscan_{eps}"].append(clusters) + + if batch_idx == 0: + for i, seg in enumerate(clusters): + pred = label2color(seg).transpose(1, 2, 0)[..., :3] + gt_vis = label2color(gt[i]).transpose(1, 2, 0)[..., :3] + raw = x[i].cpu().repeat_interleave(2, dim=0)[:3, cut:-cut, cut:-cut].numpy().transpose(1, 2, 0) + imsave(f"{eval_directory}/instances_{batch_idx}_{i:03}.png", + np.concatenate((raw, pred, gt_vis), axis=0)) + imsave(f"{eval_directory}/instances_{batch_idx}_{i:03}_raw.png", raw) + imsave(f"{eval_directory}/instances_{batch_idx}_{i:03}_pred.png", pred) + imsave(f"{eval_directory}/instances_{batch_idx}_{i:03}_gt.png", gt_vis) + + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + self.full_evaluation(trainer, pl_module, batch, batch_idx, dataloader_idx) + + def on_validation_start(self, trainer, pl_module): + # reset segmentation values + self.gt_segs = [] + self.pred_segs = {} + + eval_directory = self.create_eval_dir(pl_module) + eval_data_file = f"{eval_directory}/embedding_{pl_module.local_rank}.zarr" + zarr.open(eval_data_file, mode="w") + + def on_validation_end(self, trainer, pl_module): + gt = np.concatenate(self.gt_segs, axis=0) + + eval_directory = self.create_eval_dir(pl_module) + eval_data_file = f"{eval_directory}/embedding_{pl_module.local_rank}.zarr" + z_array = zarr.open(eval_data_file, mode="a") + z_array.create_dataset(f"gt", data=gt, compression='gzip') + + for k in self.pred_segs: + predictions = np.concatenate(self.pred_segs[k], axis=0) + z_array.create_dataset(f"predictions_{k}", data=predictions, compression='gzip') + + all_metrics = self.metrics.calc_object_stats(gt, predictions) + all_metrics.to_csv(f"{eval_directory}/score_{k}.csv") + + +def segment_with_meanshift(embedding, + bandwidth, + mask=None, + reduction_probability=0.1, + cluster_all=False): + ams = AnchorMeanshift(bandwidth, + reduction_probability=reduction_probability, + cluster_all=cluster_all) + return ams(embedding, mask=mask) + 1 + + +class AnchorMeanshift(): + def __init__(self, bandwidth, reduction_probability=0.1, cluster_all=False): + self.ms = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all) + self.reduction_probability = reduction_probability + + def compute_ms(self, X): + if self.reduction_probability < 1.: + X_reduced = X[np.random.rand(len(X)) < self.reduction_probability] + ms_seg = self.ms.fit(X_reduced) + else: + ms_seg = self.ms.fit(X) + + ms_seg = self.ms.predict(X) + + return ms_seg + + def compute_masked_ms(self, embedding, mask=None): + c, w, h = embedding.shape + if mask is not None: + assert len(mask.shape) == 2 + if mask.sum() == 0: + return -1 * np.ones(mask.shape, dtype=np.int32) + resh_emb = embedding.permute(1, 2, 0)[mask].view(-1, c) + else: + resh_emb = embedding.permute(1, 2, 0).view(w * h, c) + resh_emb = resh_emb.contiguous().numpy() + + ms_seg = self.compute_ms(resh_emb) + if mask is not None: + ms_seg_spatial = -1 * np.ones(mask.shape, dtype=np.int32) + ms_seg_spatial[mask] = ms_seg + ms_seg = ms_seg_spatial + else: + ms_seg = ms_seg.reshape(w, h) + return ms_seg + + def __call__(self, embedding, mask=None): + segmentation = [] + for j in range(len(embedding)): + mask_slice = mask[j] if mask is not None else None + ms_seg = self.compute_masked_ms(embedding[j], mask=mask_slice) + segmentation.append(ms_seg) + + return np.stack(segmentation) + + +class SegmentationValidation(Callback): + + def __init__(self, name, log_dir=None, min_size=30): + super().__init__() + self.min_size = min_size + self.log_dir = log_dir + self.name = name + + def on_validation_epoch_start(self, trainer, pl_module): + """Called when the validation loop begins.""" + self.scores = {} + self.out_dir = self.create_eval_dir(pl_module) + + def create_eval_dir(self, pl_module): + eval_directory = os.path.abspath(os.path.join(pl_module.logger.log_dir, + os.pardir, + os.pardir, + "evaluation", + f"{pl_module.global_step:08d}")) + + os.makedirs(eval_directory, exist_ok=True) + self.output_file = f"{eval_directory}/evaluation_{pl_module.local_rank}.zarr" + # create a new zarr file (remove if it exists already) + zarr.open(self.output_file, "w") + return eval_directory + + def on_validation_end(self, trainer, pl_module): + eval_score_file = f"{self.out_dir}/scores_{pl_module.local_rank}.csv" + scores = evaluate_predicted_zarr(self.output_file, eval_score_file) + pl_module.metrics = scores.to_dict() + + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + raise NotImplementedError() + + +class AffinitySegmentationValidation(SegmentationValidation): + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + network_prediction = infer(batch, pl_module, pl_module.valid_crop, sigmoid=True) + + raw = batch[0].cpu().numpy() + mws_segmentation_supervised = affinity_segmentation( + raw, network_prediction, min_size=self.min_size) + + outzarr = zarr.open(self.output_file, "a") + zarr_append("raw", raw, outzarr) + zarr_append("gt", batch[1].cpu().numpy(), outzarr) + zarr_append("prediction", network_prediction, outzarr) + zarr_append("segmentation", mws_segmentation_supervised, outzarr, attr=("name", self.name)) + + +class StardistSegmentationValidation(SegmentationValidation): + + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + network_prediction = infer(batch, pl_module, pl_module.valid_crop) + + raw = batch[0].cpu().numpy() + network_prediction[:, 0] = torch.from_numpy(network_prediction[:, 0]).sigmoid().numpy() + + mws_segmentation_supervised = stardist_segmentation( + raw, network_prediction, min_size=self.min_size) + + outzarr = zarr.open(self.output_file, "a") + zarr_append("raw", raw, outzarr) + zarr_append("gt", batch[1][None].cpu().numpy(), outzarr) + if batch_idx == 0: + zarr_append("prediction", network_prediction, outzarr) + zarr_append("segmentation", mws_segmentation_supervised[None], outzarr, attr=("name", self.name)) + + +class CellposeSegmentationValidation(SegmentationValidation): + + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + network_prediction = infer(batch, pl_module, pl_module.valid_crop) + + raw = batch[0].cpu().numpy() + segmentation_supervised = cellpose_segmentation( + raw, network_prediction, min_size=self.min_size) + + outzarr = zarr.open(self.output_file, "a") + zarr_append("raw", raw, outzarr) + zarr_append("gt", batch[1][None].cpu().numpy(), outzarr) + if batch_idx == 0: + zarr_append("prediction", network_prediction, outzarr) + zarr_append("segmentation", segmentation_supervised[None], outzarr, attr=("name", self.name)) + + +def evaluate_predicted_zarr(zarr_file, + score_file_out, + gt_key="gt", + seg_key="segmentation", + hide_report=True, + progbar=False, + mask=False): + # compute all metrics + met = metrics.Metrics('colocseg') + zarrin = zarr.open(zarr_file) + if mask is None: + gt = zarrin[gt_key] + predictions = zarrin[seg_key] + else: + gt = zarrin[gt_key][mask] + predictions = zarrin[seg_key][mask] + + if hide_report: + met.print_object_report = lambda a: None + + object_metrics = met.calc_object_stats(gt, predictions, progbar=progbar) + scores = segmentation_metric(gt, predictions, return_matches=True) + + with open(score_file_out[:-4] + "_matches.json", 'w') as file: + json.dump(scores, file) + + object_metrics.to_csv(score_file_out) + obj_mean = object_metrics.mean(axis=0) + obj_mean["seg_w"] = scores["seg"] + obj_mean["recall_w"] = scores["recall"] + obj_mean["precision_w"] = scores["precision"] + obj_mean["f1_w"] = scores["f1"] + obj_mean["sum_gt_objects"] = scores["sum_gt_objects"] + obj_mean["sum_pred_objects"] = scores["sum_pred_objects"] + + taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + for tau in taus: + stats = matching_dataset(gt, predictions, thresh=tau, show_progress=False)._asdict() + for k in stats: + obj_mean[f"std_{k}_{tau:0.1f}"] = stats[k] + + return obj_mean diff --git a/colocseg/infer_pseudo_gt_from_mean_std.py b/colocseg/infer_pseudo_gt_from_mean_std.py new file mode 100644 index 0000000..b299e8e --- /dev/null +++ b/colocseg/infer_pseudo_gt_from_mean_std.py @@ -0,0 +1,77 @@ +from scipy.ndimage import gaussian_filter +from colocseg.evaluation import segment_with_meanshift +from colocseg.utils import zarr_append, remove_border, sizefilter +from colocseg.evaluation import evaluate_predicted_zarr +from colocseg.evaluation import * +import zarr +import numpy as np +import torch +import sys +import stardist +from colocseg.utils import remove_border +from skimage.filters import threshold_otsu + +# load miniunet +result_dataset = sys.argv[1] +input_dataset = sys.argv[2] +input_name = sys.argv[3] +result_name = sys.argv[4] +t = int(sys.argv[5]) + +th = sys.argv[6] +if th != "otsu": + th = float(th) + +zin = zarr.open(result_dataset, "r") +zout = zarr.open(result_dataset, "a") + +def msseg(emb, raw, ac, bw, th=0.21): + emb = torch.from_numpy(emb) + emb[:, 1] += torch.arange(emb.shape[2])[None, :, None] + emb[:, 0] += torch.arange(emb.shape[3])[None, None, :] + + if th == "otsu": + mask = (threshold_otsu(ac) > ac) + mask = mask[None] + elif th == 0: + mask = None + else: + ac = ac.copy() + ac -= ac.min() + ac /= ac.max() + mask = (th > ac) + mask = mask[None] + + seg = segment_with_meanshift(emb, + bw, + mask=mask, + reduction_probability=0.1, + cluster_all=False)[0] + seg = stardist.fill_label_holes(seg) + seg = remove_border(raw[..., 0], + raw[..., 1], + seg[None])[0] + seg = sizefilter(seg, 10) + return seg + + +if __name__ == "__main__": + + + ameanstd = zin[input_name] + with np.load(input_dataset) as data: + raw = data['X'][t:t + 1] + + ac = ameanstd[t, 2] + bw = 7 + emb = ameanstd[t:t + 1, :2] + meanmeanshift_seg = msseg(emb, raw, ac, bw, th=th) + w, h = meanmeanshift_seg.shape + + out_key = result_name + outarr = zout.require_dataset(out_key, + shape=(raw.shape[0], w, h, 1), + chunks=(1, w, h, 1), + compression='gzip', + dtype=np.int32) + outarr[t, ..., 0] = meanmeanshift_seg diff --git a/colocseg/infer_segmentation.py b/colocseg/infer_segmentation.py new file mode 100644 index 0000000..e2860d4 --- /dev/null +++ b/colocseg/infer_segmentation.py @@ -0,0 +1,72 @@ +import argparse +import sys + +import numpy as np +import torch +import zarr + +from colocseg.evaluation import * +from colocseg.inference import (affinity_segmentation, infer, + stardist_segmentation) +from colocseg.model import Unet2D + +parser = argparse.ArgumentParser() +parser.add_argument("name") +parser.add_argument("checkpoint") +parser.add_argument("dataset_file") +parser.add_argument("split") +parser.add_argument("out_channels", type=int) +parser.add_argument("valid_crop", type=int) +parser.add_argument("output_file") +parser.add_argument("min_size", type=int) +parser.add_argument("max_idx", type=int) +parser.add_argument("idx", type=int) +args = parser.parse_args() + + +model_state_dict = torch.load(args.checkpoint)["model_state_dict"] +model = Unet2D(2, + args.out_channels, + depth=3) +model.load_state_dict(model_state_dict) +model = model.cuda() + +zin = zarr.open(args.dataset_file, "a") +asv = AnchorSegmentationValidation() + +with torch.no_grad(): + + raw = zin[f"{args.split}/raw"][args.idx:args.idx + 1] + raw = np.transpose(raw, (0, 3, 1, 2)) + x = torch.from_numpy(raw).cuda() + y = zin[f"{args.split}/gt"][args.idx:args.idx + 1] + y = np.transpose(y, (0, 3, 1, 2)) + + print(x.shape, y.shape) + network_prediction = infer([x, y], model, args.valid_crop) + print(network_prediction.shape) + + mws_segmentation_supervised = stardist_segmentation( + raw, network_prediction, min_size=args.min_size) + mws_segmentation_supervised = mws_segmentation_supervised[None] + + outzarr = zarr.open(args.output_file, "w") + outzarr.require_dataset(f"raw", + shape=(args.max_idx, ) + raw.shape[1:], + chunks=raw.shape, + dtype=np.float32, + compression='gzip')[args.idx:args.idx + 1] = raw + outzarr.require_dataset(f"gt", shape=(args.max_idx, ) + y.shape[1:], + chunks=y.shape, + dtype=np.int32, + compression='gzip')[args.idx:args.idx + 1] = y + outzarr.require_dataset(f"prediction", shape=(args.max_idx, ) + + network_prediction.shape[1:], + dtype=np.float32, + chunks=network_prediction.shape, + compression='gzip')[args.idx:args.idx + 1] = network_prediction + outzarr.require_dataset(f"segmentation", shape=(args.max_idx, ) + + mws_segmentation_supervised.shape[1:], + chunks=mws_segmentation_supervised.shape, + dtype=np.int32, + compression='gzip')[args.idx:args.idx + 1] = mws_segmentation_supervised diff --git a/colocseg/infer_spatial_embeddings.py b/colocseg/infer_spatial_embeddings.py new file mode 100644 index 0000000..8feb23c --- /dev/null +++ b/colocseg/infer_spatial_embeddings.py @@ -0,0 +1,99 @@ +# from colocseg.evaluation import * +import zarr +import numpy as np +import torch +import sys +from colocseg.model import Unet2D + +# load miniunet +checkpoint = sys.argv[1] +output_file = sys.argv[2] +result_name = sys.argv[3] +data_file = sys.argv[4] +idx = int(sys.argv[5]) + +if len(sys.argv) > 6: + input_key = sys.argv[6] +else: + input_key = "raw" + +if len(sys.argv) > 7: + in_channels = int(sys.argv[7]) +else: + in_channels = 2 + +if len(sys.argv) > 8: + features_in_last_layer = int(sys.argv[8]) +else: + features_in_last_layer = 64 + +transpose = True +if len(sys.argv) > 9: + transpose = sys.argv[9] == "transpose" + +model_state_dict = torch.load(checkpoint)["model_state_dict"] +model = Unet2D(in_channels, 2, num_fmaps=256, features_in_last_layer=features_in_last_layer) +model.load_state_dict(model_state_dict) +model = model.cuda() + +p_salt = 0.01 +zout = zarr.open(output_file, "a") +split_limit = 800 +step = 120 + +with torch.no_grad(): + with np.load(data_file) as data: + x = data['X'][idx:idx + 1] + if transpose: + x = np.transpose(x, (0, 3, 1, 2)) + x = np.pad(x, ((0, 0), (0, 0), (8, 8), (8, 8)), mode='constant') + print(x.shape) + clean_input = torch.from_numpy(x.astype(np.float32)).cuda() + + predictions = [] + for salt_value in [0.5, 1.0]: + for _ in range(16): + noisy_input = clean_input.detach().clone() + rnd = torch.rand(*noisy_input.shape).cuda() + noisy_input[rnd <= p_salt] = salt_value + + if x.shape[-1] > split_limit: + pred = [] + for idx_low in range(8, x.shape[-1] - 8, step): + inp = noisy_input[..., idx_low - 8:idx_low + step + 8] + pred.append(model(inp)[0].detach().cpu()) + pred = torch.cat(pred, dim=-1) + else: + pred = model(noisy_input)[0].detach().cpu() + predictions.append(pred) + + emb_std, emb = torch.std_mean(torch.stack(predictions, dim=0), dim=0, keepdim=False, unbiased=False) + emb_std = emb_std.sum(dim=0, keepdim=True) + emb_out = torch.cat((emb, emb_std), dim=0) + c, w, h = emb_out.shape + key = f'{result_name}' + with np.load(data_file) as data: + b = data['X'].shape[0] + + out_ds = zout.require_dataset(key, + shape=(b, c, w, h), + chunks=(1, c, w, h), + compression='gzip', + dtype=np.float32) + print(emb_out.shape) + out_ds[idx] = emb_out + + # emb[:, 1] += torch.arange(emb.shape[2])[None, :, None] + # emb[:, 0] += torch.arange(emb.shape[3])[None, None, :] + # seg = asv.meanshift_segmentation(emb, ms_bandwidths)[ms_bandwidths[0]][0] + # seg += 1 + # seg = remove_border(zin[f"{split}/raw"][idx:idx + 1, ..., 0], + # zin[f"{split}/raw"][idx:idx + 1, ..., 1], seg[None])[0] + # w, h = seg.shape[-2:] + # if f'{split}/{result_name}' not in zin: + # zin.create_dataset(f'{split}/{result_name}', + # shape=(zin[f"{split}/raw"].shape[0], w, h, 1), + # chunks=(1, w, h, 1), + # compression='gzip', + # dtype=np.int32) + # zin[f'{split}/{result_name}'][idx, ..., 0] = seg diff --git a/colocseg/inference.py b/colocseg/inference.py new file mode 100644 index 0000000..ee5d7db --- /dev/null +++ b/colocseg/inference.py @@ -0,0 +1,125 @@ +import numpy as np +import stardist +import torch +import torch.nn.functional as F +import cellpose +from affogato.segmentation import compute_mws_segmentation + +from cellpose.dynamics import compute_masks +from colocseg.transforms import get_offsets +from colocseg.utils import remove_border, sizefilter, sizefilter_batch + + +def crop_to_fit(tensor, target_shape): + remaining_cut_1_l = (tensor.shape[-1] - target_shape[-1]) // 2 + remaining_cut_1_r = (tensor.shape[-1] - target_shape[-1]) - remaining_cut_1_l + remaining_cut_2_l = (tensor.shape[-2] - target_shape[-2]) // 2 + remaining_cut_2_r = (tensor.shape[-2] - target_shape[-2]) - remaining_cut_2_l + return tensor[..., remaining_cut_2_l:-remaining_cut_2_r, remaining_cut_1_l:-remaining_cut_1_r] + + +def mws_segmentation(affinities, mask, seperating_channel=2, offsets=None, strides=(4, 4)): + + offsets = get_offsets() if offsets is None else offsets + attractive_repulsive_weights = affinities.copy() + attractive_repulsive_weights[:, :seperating_channel, ...] *= -1 + attractive_repulsive_weights[:, :seperating_channel, ...] += +1 + predicted_segmentation = [] + + for i in range(attractive_repulsive_weights.shape[0]): + predicted_segmentation.append(compute_mws_segmentation( + attractive_repulsive_weights[i], + offsets, + seperating_channel, + strides=strides, + randomize_strides=True, + mask=mask[i])) + return np.stack(predicted_segmentation) + + +def infer(batch, model, valid_crop, sigmoid=False): + + x, y = batch + p2d = (valid_crop * 2, valid_crop * 2, valid_crop * 2, valid_crop * 2) + x_padded = F.pad(x, p2d, mode='constant') + network_prediction = model.forward(x_padded)[-1] + network_prediction = crop_to_fit(network_prediction, y.shape) + if sigmoid: + network_prediction.sigmoid_() + + return network_prediction.cpu().numpy() + + +def affinity_segmentation(raw, network_prediction, min_size): + + segmentation = mws_segmentation( + network_prediction[:, 1:], + mask=network_prediction[:, 0] > 0.5) + + # the validation set contains a large amount of padding + # remove all segments outside of valid data area + segmentation = remove_border(raw[:, 0], raw[:, 1], segmentation) + segmentation = sizefilter_batch(segmentation, min_size=min_size) + + return segmentation.astype(np.int32) + + +def stardist_instances_from_prediction(dist, prob, prob_thresh=0.486166, nms_thresh=0.5, grid=(1, 1)): + points, probi, disti = stardist.nms.non_maximum_suppression(dist, prob, grid=grid, + prob_thresh=prob_thresh, nms_thresh=nms_thresh) + img_shape = prob.shape + return stardist.geometry.polygons_to_label(disti, points, prob=probi, shape=img_shape) + + +def batched_stardist_inference(predictions, prob_thresh=0.486166, nms_thresh=0.5, grid=(1, 1)): + + predicted_segmentation = [] + for i in range(predictions.shape[0]): + dist = np.transpose(predictions[i, 1:], (1, 2, 0)) + prob = predictions[i, 0] + assert dist.shape[-1] == 16, f"Unexpected stardist channels, dist.shape={dist.shape}" + predicted_segmentation.append( + stardist_instances_from_prediction( + dist, + prob, grid=grid, + prob_thresh=prob_thresh, nms_thresh=nms_thresh)) + + return np.stack(predicted_segmentation) + + +def stardist_segmentation(raw, network_prediction, min_size, prob_thresh=0.486166): + segmentation = batched_stardist_inference(network_prediction, prob_thresh=prob_thresh) + # the validation set contains a large amount of padding + # remove all segments outside of valid data area + segmentation = remove_border(raw[:, 0], raw[:, 1], segmentation) + segmentation = sizefilter_batch(segmentation, min_size=min_size) + + return segmentation.astype(np.int32) + +def cellpose_instances_from_prediction(network_prediction): + flow = network_prediction[:2] + distance_field = network_prediction[2] + try: + seg, _, _ = compute_masks(flow, distance_field, mask_threshold=1.0, flow_threshold=None, min_size=10) + except ValueError: + seg = np.zeros(distance_field.shape) + return seg + + +def batched_celpose_inference(predictions): + predicted_segmentation = [] + for i in range(predictions.shape[0]): + predicted_segmentation.append( + cellpose_instances_from_prediction(predictions[i])) + return np.stack(predicted_segmentation) + + +def cellpose_segmentation(raw, network_prediction, min_size): + + segmentation = batched_celpose_inference(network_prediction) + # the validation set contains a large amount of padding + # remove all segments outside of valid data area + segmentation = remove_border(raw[:, 0], raw[:, 1], segmentation) + segmentation = sizefilter_batch(segmentation, min_size=min_size) + + return segmentation.astype(np.int32) diff --git a/colocseg/loss.py b/colocseg/loss.py new file mode 100644 index 0000000..908f0bf --- /dev/null +++ b/colocseg/loss.py @@ -0,0 +1,254 @@ +import gc +import time +from time import sleep + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pylab +import torch +from inferno.extensions.criteria.set_similarity_measures import \ + SorensenDiceLoss +from matplotlib import collections as mc +from skimage.io import imsave +from torch import Tensor +from torch.nn import functional as F +from torch.nn import MSELoss, BCEWithLogitsLoss +from torch.nn.modules.module import Module +from cellpose.core import MaskedLoss, DerivativeLoss, WeightedLoss, ArcCosDotLoss, NormLoss, DivergenceLoss +from colocseg.utils import cluster_embeddings, label2color + +matplotlib.use('agg') + + +class AnchorLoss(Module): + r""" + + Args: + anchor_radius (float): The attraction of anchor points is non linear (sigmoid function). + The used nonlinear distance is sigmoid(euclidean_distance - anchor_radius) + Therefore the anchors experience a stronger attraction when their distance is + smaller than the anchor_radius + """ + + def __init__(self, temperature) -> None: + super().__init__() + self.temperature = temperature + + def distance_fn(self, e0, e1): + diff = (e0 - e1) + return diff.norm(2, dim=-1) + + def nonlinearity(self, distance): + return 1 - (-distance.pow(2) / self.temperature).exp() + + def forward(self, anchor_embedding, reference_embedding) -> Tensor: + # compute all pairwise distances of anchor embeddings and reference embedding + # reference embedding are detached, to avoid biased gradients at the image boundaries + dist = self.distance_fn(anchor_embedding, reference_embedding.detach()) + # dist.shape = (b, p, p) + nonlinear_dist = self.nonlinearity(dist) + return nonlinear_dist.sum() + + def absoute_embedding(self, embedding, abs_coords): + return embedding + abs_coords + + +class AnchorPlusContrastiveLoss(AnchorLoss): + + def __init__(self, *args) -> None: + super().__init__(*args) + self.ce = torch.nn.CrossEntropyLoss(ignore_index=-1) + self.weight = 10. + + def forward(self, embedding, contr_emb, abs_coords, patch_mask) -> Tensor: + # compute all pairwise distances of anchor embeddings + dist = self.distance_fn(embedding, abs_coords) + # dist.shape = (b, p, p) + + nonlinear_dist = self.nonlinearity(dist) + # only matched patches (e.g. patches close in proximity) + # contripute to the loss + nonlinear_dist = nonlinear_dist[patch_mask == 1] + + loss = nonlinear_dist.sum() + + if contr_emb is None: + return loss + + try: + cluster_labels = cluster_embeddings(embedding + abs_coords) + contr_emb = F.normalize(contr_emb, dim=-1) + cum_mean_clusters = [] + + for b in range(len(embedding)): + with torch.no_grad(): + mean_clusters = [contr_emb[b, cluster_labels[b] == i].mean( + axis=0) for i in np.unique(cluster_labels[b]) if i >= 0] + if len(mean_clusters) > 0: + mean_clusters = torch.stack(mean_clusters, dim=-1) + cum_mean_clusters.append(mean_clusters) + + cum_mean_clusters = torch.cat(cum_mean_clusters, dim=-1) + stacked_contr_emb = contr_emb.view(-1, cum_mean_clusters.shape[0]) + logits = torch.matmul(stacked_contr_emb, cum_mean_clusters) + target = torch.from_numpy(np.concatenate(cluster_labels, axis=0)).long().to(logits.device) + bce_loss = self.ce(logits, target) + loss += self.weight * bce_loss + except: + print("clustering failed! Returning anchor loss") + + return loss + + +class StardistLoss(torch.nn.Module): + """Loss for stardist predicsions combines BCE loss for probabilities + with MAE (L1) loss for distances + + Args: + weight: Distance loss weight. Total loss will be bce_loss + weight * l1_loss + """ + + def __init__(self, weight=1.): + + super().__init__() + self.weight = weight + + def forward(self, prediction, target, mask=None): + # Predicted distances errors are weighted by object prob + target_prob = target[:, :1] + predicted_prob = prediction[:, :1] + target_dist = target[:, 1:] + predicted_dist = prediction[:, 1:] + + if mask is not None: + target_prob = mask * target_prob + # do not train foreground prediction when mask is supplied + predicted_prob = predicted_prob.detach() + + l1loss_pp = F.l1_loss(predicted_dist, + target_dist, + reduction='none') + + ignore_mask_provided = target_prob.min() < 0 + if ignore_mask_provided: + # ignore label was supplied + ignore_mask = target_prob >= 0. + # add one to avoid division by zero + imsum = ignore_mask.sum() + if imsum == 0: + print("WARNING: Batch with only ignorelabel encountered!") + return 0*l1loss_pp.sum() + + l1loss = ((target_prob * ignore_mask) * l1loss_pp).sum() / imsum + bceloss = F.binary_cross_entropy_with_logits(predicted_prob[ignore_mask], + target_prob[ignore_mask], + reduction='sum') / imsum + return self.weight * l1loss + bceloss + + # weight predictions by target probs + l1loss = (target_prob * l1loss_pp).mean() + bceloss = F.binary_cross_entropy_with_logits(predicted_prob, + target_prob, + reduction='mean') + return self.weight * l1loss + bceloss + + +class RegressionLoss(torch.nn.Module): + """MAE (L1) regression loss""" + + def forward(self, prediction, target, mask=None): + + if mask is not None: + target = mask * target + prediction = mask * prediction + + l1loss = F.l1_loss(prediction, + target) + return l1loss + + +class AffinityLoss(torch.nn.Module): + """Loss for affiniy predicsions combines SorensenDiceLoss loss for affinities + with BCE loss for foreground background + + Args: + weight: affinity loss weight. Total loss will be bce_loss + weight * aff_loss + """ + + def __init__(self, weight=0.1): + super().__init__() + self.sd_loss = SorensenDiceLoss() + self.fgbg_loss = torch.nn.BCEWithLogitsLoss() + + def forward(self, prediction, target): + bceloss = self.fgbg_loss(prediction[:, :1], target[:, :1]) + aff_loss = self.sd_loss(prediction[:, 1:].sigmoid(), target[:, 1:]) + return self.weight * aff_loss + bceloss + + +class CellposeLoss(torch.nn.Module): + """Loss for cellpose flow predictions + adapted from https://github.com/MouseLand/cellpose + """ + + def __init__(self): + super().__init__() + self.criterion = MSELoss(reduction='mean') + self.criterion2 = BCEWithLogitsLoss(reduction='mean') + self.criterion6 = MaskedLoss() + self.criterion11 = DerivativeLoss() + self.criterion12 = WeightedLoss() + self.criterion14 = ArcCosDotLoss() + self.criterion15 = NormLoss() + self.criterion16 = DivergenceLoss() + + def forward(self, prediction, target): + """ Loss function for Omnipose. + + Parameters + -------------- + target: ND-array, float + transformed labels in array [nimg x nchan x xy[0] x xy[1]] + target[:,0] distance fields + target[:,1:3] flow fields + target[:,3] boundary fields + target[:,4] boundary-emphasized weights + + prediction: ND-tensor, float + network predictions + prediction[:,:2] flow fields + prediction[:,2] distance fields + prediction[:,3] boundary fields + + """ + + veci = target[:,1:3] + dist = target[:,0] + boundary = target[:,3] + cellmask = dist>0 + w = target[:,4] + dist = dist + boundary = boundary + cellmask = cellmask.bool() + + flow = prediction[:,:2] # 0,1 + dt = prediction[:,2] + bd = prediction[:,3] + a = 10. + + wt = torch.stack((w,w),dim=1) + ct = torch.stack((cellmask,cellmask),dim=1) + + loss1 = 10.*self.criterion12(flow,veci,wt) #weighted MSE + loss2 = self.criterion14(flow,veci,w,cellmask) #ArcCosDotLoss + loss3 = self.criterion11(flow,veci,wt,ct)/a # DerivativeLoss + loss4 = 2.*self.criterion2(bd,boundary) + loss5 = 2.*self.criterion15(flow,veci,w,cellmask) # loss on norm + loss6 = 2.*self.criterion12(dt,dist,w) #weighted MSE + loss7 = self.criterion11(dt.unsqueeze(1), + dist.unsqueeze(1), + w.unsqueeze(1), + cellmask.unsqueeze(1))/a + + return loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7 diff --git a/colocseg/loss_supervised.py b/colocseg/loss_supervised.py new file mode 100644 index 0000000..1534b51 --- /dev/null +++ b/colocseg/loss_supervised.py @@ -0,0 +1,80 @@ +from torch.nn.modules.module import Module +import torch +from torch import Tensor +from torch.nn import functional as F +import gc +from time import sleep + +import matplotlib +matplotlib.use('agg') +from matplotlib import collections as mc +import pylab +import matplotlib.pyplot as plt +import time +from skimage.io import imsave + +class SupervisedInstanceEmbeddingLoss(Module): + def __init__(self, push_margin): + super().__init__() + self.push_margin = push_margin + + def pull_distance(self, x, y, dim_channels, dim_samples): + return (x - y).norm(p=2, dim=dim_channels).mean(dim=dim_samples) + + def push_distance_measure(self, x, y, dim_channels): + return (self.push_margin - (x-y).norm(p=2, dim=dim_channels)).relu_() + + def push_distance(self, centroids, dim_channels, dim_samples): + assert centroids.dim() == 2 + distance_matrix = self.push_distance_measure( + centroids.unsqueeze(dim_samples), + centroids.unsqueeze(dim_samples+1), + dim_channels=-1, + ) + # select vectorized upper triangle of distance matrix + n_clusters = distance_matrix.shape[0] + upper_tri_index = torch.arange(1, n_clusters * n_clusters + 1) \ + .view(n_clusters, n_clusters) \ + .triu(diagonal=1).nonzero().transpose(0, 1) + cluster_distances = distance_matrix[upper_tri_index[0], upper_tri_index[1]] + + return cluster_distances.mean() + + def forward(self, abs_embedding, coordinates, y, split_pull_push=False): + pull_loss = torch.tensor(0.).to(abs_embedding.device) + push_loss = torch.tensor(0.).to(abs_embedding.device) + + for b in range(len(y)): + cx = coordinates[b, :, 1].long() + cy = coordinates[b, :, 0].long() + + y_per_patch = y[b, cx, cy] + centroids = [] + dim_channels, dim_samples = 1, 0 + pull_over_instances = 0 + + for idx in torch.unique(y_per_patch): + patch_mask = y_per_patch == idx + if idx == 0: + # skip background instance + continue + + instance_embedding = abs_embedding[b, patch_mask] + + centroid = instance_embedding.mean(dim=dim_samples, + keepdim=True) + centroids.append(centroid) + pull_over_instances = pull_over_instances + \ + self.pull_distance( + centroid, instance_embedding, dim_channels, dim_samples) + + # add push loss between centroids + if len(centroids) > 1: + pull_loss = pull_loss + (pull_over_instances / len(centroids)) + push_loss = push_loss + self.push_distance(torch.cat(centroids, dim=0), + dim_channels, dim_samples) + + if split_pull_push: + return pull_loss, push_loss + else: + return pull_loss + push_loss diff --git a/colocseg/metrics.py b/colocseg/metrics.py new file mode 100644 index 0000000..5cc7ae1 --- /dev/null +++ b/colocseg/metrics.py @@ -0,0 +1,73 @@ +from scipy.stats import hmean +import numpy as np + +def segmentation_metric(gt_label, res_label, overlap_threshold=0.5, match_iou=0.5, return_matches=False): + + seg = 0. + n_matches = 0 + counter = 0 + imgCounter = 0 + assert(overlap_threshold >= 0.5) + matches = {} + + compare_dtype = np.dtype([('res', res_label.dtype), ('gt', gt_label.dtype)]) + + sum_iou = 0 + sum_gt_objects = 0 + sum_pred_objects = 0 + + for t in range(len(res_label)): + + label_tuples = np.empty(res_label[t].shape, dtype=compare_dtype) + label_tuples['res'] = res_label[t] + label_tuples['gt'] = gt_label[t] + + both_foreground = np.logical_and(label_tuples['res'] > 0, label_tuples['gt'] > 0) + index_pairs, intersections = np.unique(label_tuples[both_foreground], return_counts=True) + gt_indexes, gt_size = np.unique(label_tuples['gt'][label_tuples['gt'] > 0], return_counts=True) + sum_gt_objects += len(gt_indexes) + sum_pred_objects += len(np.unique(np.unique(label_tuples['res'][label_tuples['res'] > 0]))) + + if return_matches: + for gt_idx in gt_indexes: + matches[(t, gt_idx)] = (0., 0) + + for (res_idx, gt_idx), intersection in zip(index_pairs, intersections): + gt_size = (label_tuples['gt'] == gt_idx).sum() + res_size = (label_tuples['res'] == res_idx).sum() + overlap = intersection / gt_size + if overlap > overlap_threshold: + iou = intersection / (gt_size + res_size - intersection) + sum_iou += iou + if return_matches: + matches[(t, gt_idx)] = (iou, res_idx) + if iou > match_iou: + n_matches += 1 + + scores = {} + + scores["sum_gt_objects"] = sum_gt_objects + scores["sum_pred_objects"] = sum_pred_objects + + if sum_gt_objects == 0: + scores["seg"] = 0 + recall = 0 + else: + recall = n_matches / sum_gt_objects + scores["seg"] = sum_iou / sum_gt_objects + + if sum_pred_objects == 0: + precision = 0 + else: + precision = n_matches / sum_pred_objects + + scores["recall"] = recall + scores["precision"] = precision + scores["f1"] = hmean([recall, precision]) + if return_matches: + scores["matched_iou"] = [float(v[0]) for v in matches.values()] + scores["matched_res_idx"] = [int(v[1]) for v in matches.values()] + scores["matched_t"] = [int(k[0]) for k in matches.keys()] + scores["matched_gt_idx"] = [int(k[1]) for k in matches.keys()] + + return scores diff --git a/colocseg/model.py b/colocseg/model.py new file mode 100644 index 0000000..3a1d651 --- /dev/null +++ b/colocseg/model.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +from funlib.learn.torch.models import UNet + + +class Unet2D(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + aux_channels=0, + num_fmaps=64, + fmap_inc_factor=3, + features_in_last_layer=64, + head_type="single", + depth=1): + + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.aux_channels = aux_channels + self.features_in_last_layer = features_in_last_layer + self.head_type = head_type + + d_factors = [(2, 2), ] * depth + self.backbone = UNet(in_channels=self.in_channels, + num_fmaps=num_fmaps, + fmap_inc_factor=fmap_inc_factor, + downsample_factors=d_factors, + activation='ReLU', + batch_norm=False, + padding='valid', + num_fmaps_out=self.features_in_last_layer, + kernel_size_down=[[(3, 3), (1, 1), (1, 1), (3, 3)]] * (depth + 1), + kernel_size_up=[[(3, 3), (1, 1), (1, 1), (3, 3)]] * depth, + constant_upsample=True) + + # Commonly used Non-linear projection head + # see https://arxiv.org/pdf/1906.00910.pdf + # or https://arxiv.org/pdf/2002.05709.pdf + if head_type == "single": + self.head = torch.nn.Sequential(nn.Conv2d(self.features_in_last_layer, self.features_in_last_layer, 1), + nn.ReLU(), + nn.Conv2d(self.features_in_last_layer, out_channels + aux_channels, 1)) + elif head_type == "seq": + self.head_pre = torch.nn.Sequential(nn.Conv2d(self.features_in_last_layer, self.features_in_last_layer, 1), + nn.ReLU(), + nn.Conv2d(self.features_in_last_layer, self.features_in_last_layer + aux_channels, 1)) + self.head_post = torch.nn.Sequential(nn.Conv2d(self.features_in_last_layer + aux_channels, self.features_in_last_layer, 1), + nn.ReLU(), + nn.Conv2d(self.features_in_last_layer, out_channels, 1)) + elif head_type == "multi": + self.head_main = torch.nn.Sequential(nn.Conv2d(self.features_in_last_layer, self.features_in_last_layer, 1), + nn.ReLU(), + nn.Conv2d(self.features_in_last_layer, out_channels, 1)) + self.head_aux = torch.nn.Sequential(nn.Conv2d(self.features_in_last_layer, self.features_in_last_layer, 1), + nn.ReLU(), + nn.Conv2d(self.features_in_last_layer, aux_channels, 1)) + + def head_forward(self, last_layer_output): + if self.head_type == "single": + out_cat = self.head(last_layer_output) + if self.aux_channels == 0: + return out_cat + else: + return out_cat[:, self.out_channels:], out_cat[:, :self.out_channels] + elif self.head_type == "seq": + pre_out = self.head_pre(last_layer_output) + out = self.head_post(pre_out) + return pre_out[:, :self.aux_channels], out + elif self.head_type == "multi": + out = self.head_main(last_layer_output) + out_aux = self.head_aux(last_layer_output) + return out_aux, out + + + @staticmethod + def select_and_add_coordinates(output, coords): + selection = [] + # output.shape = (b, c, h, w) + for o, c in zip(output, coords): + sel = o[:, c[:, 1], c[:, 0]] + sel = sel.transpose(1, 0) + sel += c + selection.append(sel) + + # selection.shape = (b, c, p) where p is the number of selected positions + return torch.stack(selection, dim=0) + + def forward(self, raw): + h = self.backbone(raw) + return self.head_forward(h) + + def forward_and_select(self, raw, coords): + # coords.shape = (b, p, 2) + h = self.forward(raw) + return self.select_coords(h, coords) diff --git a/colocseg/test_inference.py b/colocseg/test_inference.py new file mode 100644 index 0000000..71ab699 --- /dev/null +++ b/colocseg/test_inference.py @@ -0,0 +1,145 @@ +import sys + +from argparse import ArgumentParser +from colocseg.utils import zarr_append +from tqdm import tqdm +import zarr + +from colocseg.inference import infer, stardist_segmentation, cellpose_segmentation +import torch +from glob import glob +from colocseg.model import Unet2D +import shlex +import pandas as pd +from colocseg.evaluation import evaluate_predicted_zarr +import pytorch_lightning as pl +from colocseg.datasets import TissueNetDataset +from argparse import ArgumentParser + +tissuenet_test_dataset = "path_to_ds/tissuenet_v1.0_test.npz" + +def test_from_checkpoint(root_folder, args, checkpoint_index="all", target_type="cell", tissue_type="immune"): + + cp_file = glob(f"{root_folder}/models/maxi_*.torch") + cp_file.sort(reverse=True) + if checkpoint_index != "all": + idx = int(checkpoint_index) + print("testing model index ", idx, " of ", cp_file) + if idx == -1: + cp_file = cp_file[-1:] + else: + cp_file = cp_file[idx:idx + 1] + print(cp_file) + + for checkpoint in cp_file: + model_state_dict = torch.load(checkpoint)["model_state_dict"] + model = Unet2D(args.in_channels, + args.out_channels, + args.aux_channels, + head_type=args.unet_head_type, + fmap_inc_factor=args.unet_fmap_inc_factor, + num_fmaps=32, + depth=3, + ) + model.load_state_dict(model_state_dict) + model = model.cuda().eval() + + ds = TissueNetDataset( + tissuenet_test_dataset, + augment=False, + tissue_type=tissue_type, + target_type=target_type, + ) + ds.valid_crop = 0 + + tissue_type_name = tissue_type if tissue_type is not None else "all" + + iteration = int(checkpoint.split("_")[-2]) + zout_file = f"{root_folder}/test_{target_type}_{tissue_type_name}_{iteration}.zarr" + zarr.open(zout_file, "w") + + for i, batch in tqdm(enumerate(ds)): + + with torch.no_grad(): + batch = torch.from_numpy(batch[0]).cuda()[None], torch.from_numpy(batch[1]).cuda()[None] + + network_prediction = infer(batch, model, args.valid_crop+1) + raw = batch[0].cpu().numpy() + + if args.loss_name_super == 'StardistLoss': + network_prediction[:, 0] = torch.from_numpy(network_prediction[:, 0]).sigmoid().numpy() + elif args.loss_name_super == 'CellposeLoss': + network_prediction[:, 3] = torch.from_numpy(network_prediction[:, 3]).sigmoid().numpy() + + outzarr = zarr.open(zout_file, "a") + zarr_append("gt", batch[1][None].cpu().numpy(), outzarr) + if i < 4: + zarr_append("raw", raw, outzarr) + zarr_append("prediction", network_prediction, outzarr) + + if args.loss_name_super == 'StardistLoss': + for prob_thresh in [0.3, 0.4, 0.486166]: + sd_segmentation_supervised = stardist_segmentation( + raw, network_prediction, min_size=30, prob_thresh=prob_thresh) + + if prob_thresh == 0.486166: + zarr_append("segmentation", sd_segmentation_supervised[None], outzarr, attr=( + "model", checkpoint, "th", prob_thresh)) + else: + zarr_append(f"segmentation_{prob_thresh:0.3f}", + sd_segmentation_supervised[None], outzarr, attr=("model", checkpoint, "th", prob_thresh)) + elif args.loss_name_super == 'CellposeLoss': + sd_segmentation_supervised = cellpose_segmentation( + raw, network_prediction, min_size=30) + zarr_append("segmentation", sd_segmentation_supervised[None], outzarr, attr=( + "model", checkpoint, "method", "cellpose")) + else: + raise NotImplementedError("loss name not recognized") + + for k in [k for k in zarr.open(zout_file, "r").keys() if k.startswith("segmentation")]: + scores = evaluate_predicted_zarr( + zout_file, f"{root_folder}/test_scores_{target_type}_{tissue_type_name}_individual_{iteration}_{k}.csv", + seg_key=k) + + outdict = {key: value for key, value in vars(args).items() if key[:1] != "_"} + outdict.update(scores.to_dict()) + outdict["iteration"] = iteration + outdict["tissue_type_test"] = tissue_type + + df = pd.DataFrame.from_dict(outdict, orient="index") + df.T.to_csv(f"{root_folder}/test_data_{target_type}_{tissue_type_name}_{iteration}_{k}.csv") + + +if __name__ == "__main__": + + root_folder = sys.argv[1] + sys.path.append(root_folder) + from colocseg.trainingmodules import PartiallySupervisedTrainer + from colocseg.datamodules import PartiallySupervisedDataModule + + if sys.argv[3] != "alltypes": + tissue_type = sys.argv[3] + tissue_type_name = tissue_type + else: + tissue_type = None + tissue_type_name = "alltypes" + + target_type = sys.argv[4] + + parser = ArgumentParser() + parser = PartiallySupervisedTrainer.add_model_specific_args(parser) + parser = pl.Trainer.add_argparse_args(parser) + parser = PartiallySupervisedDataModule.add_argparse_args(parser) + parser = PartiallySupervisedDataModule.add_model_specific_args(parser) + + sc_file = f"{root_folder}/train.sh" + + with open(sc_file, 'r') as f: + lines = f.read().splitlines() + call_string_pre_split_eq = shlex.split(lines[-1])[2:] + + args = parser.parse_args(call_string_pre_split_eq) + print(args.__dict__.items()) + checkpoint_index = sys.argv[2] + + test_from_checkpoint(root_folder, args, checkpoint_index) diff --git a/colocseg/train_ssl.py b/colocseg/train_ssl.py new file mode 100644 index 0000000..3bb204f --- /dev/null +++ b/colocseg/train_ssl.py @@ -0,0 +1,36 @@ +from argparse import ArgumentParser +from colocseg.datamodules import TissueNetDataModule +from colocseg.evaluation import AnchorSegmentationValidation +from colocseg.trainingmodules import SSLTrainer +from colocseg.utils import SaveModelOnValidation +from pytorch_lightning.callbacks import LearningRateMonitor +import pytorch_lightning as pl + + +if __name__ == '__main__': + + parser = ArgumentParser() + + pl.utilities.seed.seed_everything(42) + parser = SSLTrainer.add_model_specific_args(parser) + parser = pl.Trainer.add_argparse_args(parser) # , logger=WandbLogger(project='SSLAnchor')) + parser = TissueNetDataModule.add_argparse_args(parser) + parser = TissueNetDataModule.add_model_specific_args(parser) + + args = parser.parse_args() + + # init module + model = SSLTrainer.from_argparse_args(args) + + datamodule = TissueNetDataModule.from_argparse_args(args) + anchor_val = AnchorSegmentationValidation(run_segmentation=False) + lr_logger = LearningRateMonitor(logging_interval='step') + model_saver = SaveModelOnValidation() + + # init trainer + trainer = pl.Trainer.from_argparse_args(args) + + trainer.callbacks.append(model_saver) + trainer.callbacks.append(anchor_val) + trainer.callbacks.append(lr_logger) + trainer.fit(model, datamodule) diff --git a/colocseg/train_ssl_ctc.py b/colocseg/train_ssl_ctc.py new file mode 100644 index 0000000..f6faed3 --- /dev/null +++ b/colocseg/train_ssl_ctc.py @@ -0,0 +1,36 @@ +from argparse import ArgumentParser +from colocseg.datamodules import CTCDataModule +from colocseg.evaluation import AnchorSegmentationValidation +from colocseg.trainingmodules import SSLTrainer +from colocseg.utils import SaveModelOnValidation +from pytorch_lightning.callbacks import LearningRateMonitor +import pytorch_lightning as pl + + +if __name__ == '__main__': + + parser = ArgumentParser() + + pl.utilities.seed.seed_everything(42) + parser = SSLTrainer.add_model_specific_args(parser) + parser = pl.Trainer.add_argparse_args(parser) # , logger=WandbLogger(project='SSLAnchor')) + parser = CTCDataModule.add_argparse_args(parser) + parser = CTCDataModule.add_model_specific_args(parser) + + args = parser.parse_args() + + # init module + model = SSLTrainer.from_argparse_args(args) + + datamodule = CTCDataModule.from_argparse_args(args) + anchor_val = AnchorSegmentationValidation(run_segmentation=False) + lr_logger = LearningRateMonitor(logging_interval='step') + model_saver = SaveModelOnValidation() + + # init trainer + trainer = pl.Trainer.from_argparse_args(args) + + trainer.callbacks.append(model_saver) + trainer.callbacks.append(anchor_val) + trainer.callbacks.append(lr_logger) + trainer.fit(model, datamodule) diff --git a/colocseg/train_supervised.py b/colocseg/train_supervised.py new file mode 100644 index 0000000..86a2318 --- /dev/null +++ b/colocseg/train_supervised.py @@ -0,0 +1,54 @@ +from argparse import ArgumentParser +from colocseg.datamodules import PartiallySupervisedDataModule +from colocseg.evaluation import StardistSegmentationValidation, CellposeSegmentationValidation +from colocseg.trainingmodules import PartiallySupervisedTrainer +from colocseg.utils import SaveModelOnValidation +from colocseg.test_inference import test_from_checkpoint +from pytorch_lightning.callbacks import LearningRateMonitor +import pytorch_lightning as pl + + +if __name__ == '__main__': + + parser = ArgumentParser() + + # pl.utilities.seed.seed_everything(42) + parser = PartiallySupervisedTrainer.add_model_specific_args(parser) + parser = pl.Trainer.add_argparse_args(parser) + parser = PartiallySupervisedDataModule.add_argparse_args(parser) + parser = PartiallySupervisedDataModule.add_model_specific_args(parser) + + args = parser.parse_args() + + # init module + model = PartiallySupervisedTrainer.from_argparse_args(args) + datamodule = PartiallySupervisedDataModule.from_argparse_args(args) + + if args.loss_name_super == 'StardistLoss': + seg_val = StardistSegmentationValidation(f"Colocseg_TissueNet_{args.limit}", min_size=0) + elif args.loss_name_super == 'CellposeLoss': + seg_val = CellposeSegmentationValidation(f"Colocseg_TissueNet_{args.limit}", min_size=0) + else: + raise NotImplementedError("Validator for loss not implemented") + + lr_logger = LearningRateMonitor(logging_interval='step') + model_saver = SaveModelOnValidation() + + # init trainer + trainer = pl.Trainer.from_argparse_args(args) + + trainer.callbacks.append(model_saver) + trainer.callbacks.append(seg_val) + trainer.callbacks.append(lr_logger) + trainer.fit(model, datamodule) + trainer.validate(model, datamodule) + + # compute test scores + if args.tissue_type == 'all': + test_from_checkpoint(".", args, checkpoint_index=0, target_type=args.target_type, tissue_type='all') + test_from_checkpoint(".", args, checkpoint_index=0, target_type=args.target_type, tissue_type="immune") + test_from_checkpoint(".", args, checkpoint_index=0, target_type=args.target_type, tissue_type="pancreas") + test_from_checkpoint(".", args, checkpoint_index=0, target_type=args.target_type, tissue_type="lung") + test_from_checkpoint(".", args, checkpoint_index=0, target_type=args.target_type, tissue_type="skin") + else: + test_from_checkpoint(".", args, checkpoint_index=0, target_type=args.target_type, tissue_type=args.tissue_type) diff --git a/colocseg/trainingmodules.py b/colocseg/trainingmodules.py new file mode 100644 index 0000000..b0180e8 --- /dev/null +++ b/colocseg/trainingmodules.py @@ -0,0 +1,401 @@ +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +import zarr +from skimage.io import imsave +from torch.optim.lr_scheduler import MultiStepLR, ReduceLROnPlateau + +from colocseg.loss import AnchorLoss +from colocseg.loss_supervised import SupervisedInstanceEmbeddingLoss +from colocseg.model import Unet2D +from colocseg.utils import BuildFromArgparse, import_by_string, zarr_append + + +class SSLTrainer(pl.LightningModule, BuildFromArgparse): + def __init__(self, + in_channels=1, + out_channels=18, + initial_lr=1e-4, + regularization=1e-4, + temperature=10, + coordinate_offset_after_valid_unet=8, + lr_milestones=(100,) + ): + + super().__init__() + + self.save_hyperparameters() + self.out_channels = out_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.initial_lr = initial_lr + self.lr_milestones = list(int(_) for _ in lr_milestones) + self.regularization = regularization + self.temperature = temperature + self.coordinate_offset_after_valid_unet = coordinate_offset_after_valid_unet + self.build_models() + self.build_loss() + + @staticmethod + def add_model_specific_args(parser): + parser.add_argument('--out_channels', type=int) + parser.add_argument('--in_channels', type=int, default=1) + parser.add_argument('--initial_lr', type=float, default=1e-4) + parser.add_argument('--regularization', type=float, default=1e-4) + parser.add_argument('--lr_milestones', nargs='*', default=[10000, 20000]) + parser.add_argument('--temperature', type=float, default=10) + parser.add_argument('--coordinate_offset_after_valid_unet', type=int, default=8) + return parser + + def forward(self, x): + return self.mini_unet(x) + + def build_models(self): + self.mini_unet = Unet2D( + self.in_channels, + self.out_channels, + num_fmaps=256) + + def build_loss(self, ): + self.validation_loss = SupervisedInstanceEmbeddingLoss(30.) + self.anchor_loss = AnchorLoss(self.temperature) + + def training_step(self, batch, batch_nb): + + x, anchor_coordinates, refernce_coordinates = batch + emb_relative = self.mini_unet.forward(x) + emb_anchor = self.mini_unet.select_and_add_coordinates(emb_relative, anchor_coordinates) + emb_ref = self.mini_unet.select_and_add_coordinates(emb_relative, refernce_coordinates) + anchor_loss = self.anchor_loss(emb_anchor, emb_ref) + + self.log_images(x, emb_relative) + + self.log( + 'anchor_loss', + anchor_loss.detach(), + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True) + self.log( + 'mean_offset', + emb_relative.detach().mean(), + on_step=True, + on_epoch=False, + prog_bar=True, + logger=True) + self.log( + 'max_offset', + emb_relative.detach().abs().max(), + on_step=True, + on_epoch=False, + prog_bar=True, + logger=True) + self.log( + 'anchor_loss_temperature', + self.anchor_loss.temperature, + on_step=True, + prog_bar=True, + logger=True) + + if self.regularization > 0.: + reg_loss = self.regularization * emb_anchor.norm(2, dim=-1).sum() + loss = anchor_loss + reg_loss + self.log('reg_loss', reg_loss.detach(), on_step=True, prog_bar=True, logger=True) + else: + loss = anchor_loss + + tensorboard_logs = {'train_loss': loss.detach()} + tensorboard_logs['iteration'] = self.global_step + return {'loss': loss, 'log': tensorboard_logs} + + def validation_step(self, batch, batch_nb): + + x, anchor_coordinates, refernce_coordinates, y = batch + + with torch.no_grad(): + + # pad image to get full output + anchor_coordinates = anchor_coordinates + self.coordinate_offset_after_valid_unet + coavu = self.coordinate_offset_after_valid_unet + p2d = (coavu, coavu, coavu, coavu) + x_padded = F.pad(x, p2d, mode='reflect') + + embedding = self.mini_unet.forward(x_padded) + emb_anchor = self.mini_unet.select_and_add_coordinates(embedding, anchor_coordinates) + emb_ref = self.mini_unet.select_and_add_coordinates(embedding, refernce_coordinates) + loss = self.anchor_loss(emb_anchor, emb_ref) + + self.log('val_anchor_loss', loss.detach(), on_epoch=True, prog_bar=False, logger=True) + for margin in [1., 5., 10, 20., 40]: + self.validation_loss.push_margin = margin + absoute_embedding = self.anchor_loss.absoute_embedding(emb_anchor, emb_ref) + pull_loss, push_loss = self.validation_loss(emb_anchor, anchor_coordinates, y, split_pull_push=True) + + self.log( + f'val_clustering_loss_margin_pull_{margin}', + pull_loss.detach(), + on_epoch=True, + prog_bar=False, + logger=True) + self.log( + f'val_clustering_loss_margin_push_{margin}', + push_loss.detach(), + on_epoch=True, + prog_bar=False, + logger=True) + self.log( + f'val_clustering_loss_margin_both_{margin}', + (pull_loss + push_loss).detach(), + on_epoch=True, + prog_bar=False, + logger=True) + + return {'val_loss': loss} + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.initial_lr, weight_decay=0.01) + # scheduler = MultiStepLR(optimizer, milestones=self.lr_milestones, gamma=0.1) + return optimizer#], [scheduler] + + def log_images(self, raw, network_prediction): + from pathlib import Path + import numpy as np + import math + if self.global_step > 0 and (math.log2(self.global_step).is_integer() or self.global_step % 10000 == 0): + Path("img").mkdir(exist_ok=True) + x = raw.detach().cpu().numpy() + x = np.reshape(np.transpose(x, (0, 2, 1, 3)), (-1, x.shape[1] * x.shape[3])) + imsave(f"img/raw_ssl_{self.global_step}.png", x, check_contrast=False) + + x = network_prediction.detach().cpu().numpy() + x = np.reshape(np.transpose(x, (0, 2, 1, 3)), (-1, x.shape[1] * x.shape[3])) + imsave(f"img/pred_ssl_{self.global_step}.png", x, check_contrast=False) + + def log_now(self, val=False): + + if val: + if self.last_val_log == self.global_step: + return False + else: + self.last_val_log = self.global_step + return True + + if self.global_step > 1024: + return self.global_step % 2048 == 0 + else: + return self.global_step % 64 == 0 + + +class PartiallySupervisedTrainer(pl.LightningModule, BuildFromArgparse): + def __init__(self, + in_channels=1, + out_channels=17, + loss_name_super="StardistLoss", + loss_name_aux="StardistLoss", + aux_channels=17, + unet_head_type="single", + unet_fmap_inc_factor=3, + loss_alpha=0.01, + loss_delay=0, + loss_same_channel=False, + valid_crop=46, + initial_lr=1e-4, + lr_milestones=(100,), + train_without_pseudo_gt=False, + save_batches=True + ): + + super().__init__() + + self.save_hyperparameters() + self.out_channels = out_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.loss_name_super = loss_name_super + self.loss_name_aux = loss_name_aux + self.aux_channels = aux_channels + self.loss_alpha = loss_alpha + self.loss_delay = loss_delay + self.loss_same_channel = loss_same_channel + self.unet_head_type = unet_head_type + self.unet_fmap_inc_factor = unet_fmap_inc_factor + self.valid_crop = valid_crop + self.initial_lr = initial_lr + self.lr_milestones = list(int(_) for _ in lr_milestones) + self.metrics = {} + self.train_without_pseudo_gt = train_without_pseudo_gt + self.save_batches = save_batches + + self.build_models() + self.build_loss() + + @staticmethod + def add_model_specific_args(parser): + parser.add_argument('--out_channels', type=int) + parser.add_argument('--in_channels', type=int, default=1) + parser.add_argument('--valid_crop', type=int, default=46) + parser.add_argument('--initial_lr', type=float, default=1e-4) + parser.add_argument('--loss_name_super', type=str, default="StardistLoss") + parser.add_argument('--loss_name_aux', type=str, default="StardistLoss") + parser.add_argument('--unet_head_type', type=str, default="single") + parser.add_argument('--aux_channels', type=int, default=17) + parser.add_argument('--loss_alpha', type=float, default=0.01) + parser.add_argument('--loss_delay', type=int, default=0) + parser.add_argument('--unet_fmap_inc_factor', type=int, default=3) + parser.add_argument('--loss_same_channel', action='store_true') + + parser.add_argument('--lr_milestones', nargs='*', default=[10000, 20000]) + parser.add_argument('--train_without_pseudo_gt', action='store_true') + + return parser + + def forward(self, x): + return self.maxi_unet(x) + + def build_models(self): + self.maxi_unet = Unet2D( + self.in_channels, + self.out_channels, + self.aux_channels, + head_type=self.unet_head_type, + num_fmaps=32, + fmap_inc_factor=self.unet_fmap_inc_factor, + depth=3) + + def build_loss(self, ): + loss_class_super = import_by_string(f'colocseg.loss.{self.loss_name_super}') + self.criterion_super = loss_class_super() + loss_class_aux = import_by_string(f'colocseg.loss.{self.loss_name_aux}') + self.criterion_aux = loss_class_aux() + + def log_metrics(self): + for k in self.metrics: + try: + self.log(k, self.metrics[k], prog_bar=k == "f1") + except: + pass + self.metrics = {} + + def crop_to_valid(self, tensor): + return tensor[..., self.valid_crop:-self.valid_crop, self.valid_crop:-self.valid_crop] + + def log_batch(self, filename, log_now=False, **kwargs): + if self.save_batches and self.global_step % 5000 == 0 or log_now: + zout = zarr.open(filename) + for k in kwargs: + if kwargs[k] is not None: + zout[k] = kwargs[k].detach().cpu().numpy() + + def log_images(self, name, raw, network_prediction, target, gtseg): + from pathlib import Path + import numpy as np + import math + if self.global_step > 0 and math.log2(self.global_step).is_integer(): + Path("img").mkdir(exist_ok=True) + x = raw.detach().cpu().numpy() + x = np.reshape(np.transpose(x, (0, 2, 1, 3)), (-1, x.shape[1] * x.shape[3])) + imsave(f"img/raw_{name}_{self.global_step}.png", x, check_contrast=False) + + fgbg0 = network_prediction.detach().cpu()[:, 0].sigmoid().numpy() + fgbg0 = np.reshape(fgbg0, (-1, fgbg0.shape[-1])) + pred0 = network_prediction.detach().cpu()[:, 1] + pred0 = np.reshape(pred0, (-1, pred0.shape[-1])) + pred0 -= pred0.min() + pred0 /= pred0.max() + 1e-8 + imsave(f"img/pred_{name}_{self.global_step}.png", np.concatenate( + (fgbg0, + 0 * fgbg0[:1], 0 * fgbg0[:1] + 1, 0 * fgbg0[:1], + pred0), + axis=0), + check_contrast=False) + + g = gtseg.detach().cpu().reshape(-1, gtseg.shape[-1]).numpy() + # g /= g.max() + 0.01 + imsave(f"img/gt_{name}_{self.global_step}.png", g, check_contrast=False) + for c in range(target.shape[1]): + t = target[:, c].detach().cpu().reshape(-1, target.shape[-1]).numpy() + t /= t.max() + imsave(f"img/target_{name}_{self.global_step}_{c}.png", t, check_contrast=False) + + + def training_step(self, batch, batch_nb): + + add_aux_loss = self.loss_alpha > 0 and not self.train_without_pseudo_gt and self.loss_delay <= self.global_step + add_super_loss = self.loss_alpha < 1. and (self.loss_delay >= 0 or ( + self.loss_delay < 0 and -self.loss_delay <= self.global_step)) + + batch_supervised, batch_aux = batch + loss = 0. + + if add_aux_loss: + x_aux, target_aux, aux_segmentation = batch_aux + + network_prediction_aux_on_aux, network_prediction_super_on_aux = self.maxi_unet.forward(x_aux) + network_prediction_aux = network_prediction_aux_on_aux + if self.loss_same_channel: + network_prediction_aux = network_prediction_super_on_aux + + self.log_batch(f"batch_{self.global_step}_aux.zarr", + x_aux=x_aux, + network_prediction_aux=network_prediction_aux, + target_aux=target_aux, + aux_segmentation=aux_segmentation[:, None] if torch.is_tensor(aux_segmentation) else None) + target_aux = self.crop_to_valid(target_aux) + + if self.global_step % 5000 == 0: + def log_hook_full(grad_input): + outzarr = zarr.open(f"grad_{self.global_step}_aux.zarr", "w") + zarr_append("grad", grad_input.detach().cpu().numpy()[None], outzarr) + zarr_append("x_aux", (x_aux).detach().cpu().numpy()[None], outzarr) + zarr_append("target_aux", (target_aux).detach().cpu().numpy()[None], outzarr) + zarr_append("network_prediction_aux", network_prediction_aux.detach().cpu().numpy()[None], outzarr) + handle.remove() + handle = network_prediction_aux.register_hook(log_hook_full) + + self.log_images("aux", x_aux, network_prediction_aux, target_aux, aux_segmentation) + loss_pseudo = self.criterion_aux(network_prediction_aux, target_aux) + + self.log("train_loss_pseudo", loss_pseudo.detach().item(), prog_bar=True) + loss = loss + (self.loss_alpha * loss_pseudo) + + if add_super_loss: + x_super, target_super, gt_segmentation = batch_supervised + network_prediction_super = self.maxi_unet.forward(x_super)[1] + self.log_batch(f"batch_{self.global_step}_super.zarr", + x_super=x_super, + network_prediction_super=network_prediction_super, + target_super=target_super, + gt_segmentation=gt_segmentation[:, None]) + target_super = self.crop_to_valid(target_super) + self.log_images("super", x_super, network_prediction_super, target_super, gt_segmentation) + + if self.global_step % 5000 == 0: + def log_hook_full(grad_input): + outzarr = zarr.open(f"grad_{self.global_step}_super.zarr", "w") + zarr_append("grad", grad_input.detach().cpu().numpy()[None], outzarr) + zarr_append("x_super", (x_super).detach().cpu().numpy()[None], outzarr) + zarr_append("target_super", (target_super).detach().cpu().numpy()[None], outzarr) + zarr_append("network_prediction_aux", + network_prediction_super.detach().cpu().numpy()[None], outzarr) + handle2.remove() + handle2 = network_prediction_super.register_hook(log_hook_full) + + loss_super = self.criterion_super(network_prediction_super, target_super) + self.log("train_loss_super", loss_super.detach().item(), prog_bar=True) + loss = loss + ((1 - self.loss_alpha) * loss_super) + + self.log("loss_alpha", self.loss_alpha, prog_bar=False) + tensorboard_logs = {'train_loss': loss.detach().item()} + self.log_metrics() + self.log("train_loss", loss.detach().item()) + + return {'loss': loss, 'log': tensorboard_logs} + + def validation_step(self, batch, batch_nb): + return 0. + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.initial_lr, weight_decay=0.0) + scheduler = MultiStepLR(optimizer, milestones=self.lr_milestones, gamma=0.1) + return [optimizer], [scheduler] diff --git a/colocseg/transforms.py b/colocseg/transforms.py new file mode 100644 index 0000000..07cb458 --- /dev/null +++ b/colocseg/transforms.py @@ -0,0 +1,126 @@ +import numpy as np +import stardist +from embeddingutils.transforms import Segmentation2AffinitiesWithPadding +from inferno.io.transform import Transform +from cellpose.dynamics import masks_to_flows + +class ThreeclassTf(Transform): + """Convert segmentation to 3 class""" + + def __init__(self, inner_distance=4): + super().__init__() + self.inner_distance = inner_distance + + def tensor_function(self, gt): + gt_stardist = stardist.geometry.star_dist(gt, n_rays=8) + background = gt == 0 + inner = gt_stardist.min(axis=-1) > self.inner_distance + # classes 0: boundaries, 1: inner_cell, 2: background + threeclass = (2 * background) + inner + return threeclass.astype(np.long) + + +class StardistTf(Transform): + """Convert segmentation to stardist""" + + def __init__(self, n_rays=16, fill_label_holes=False): + super().__init__() + self.n_rays = n_rays + self.fill_label_holes = fill_label_holes + + def tensor_function(self, gt): + # gt = measure.label(gt) + if self.fill_label_holes: + gt = stardist.fill_label_holes(gt) + dist = stardist.geometry.star_dist(gt, n_rays=self.n_rays) + dist_mask = stardist.utils.edt_prob(gt) + + if gt.min() < 0: + # ignore label found + ignore_mask = gt < 0 + dist[ignore_mask] = 0 + dist_mask[ignore_mask] = -1 + + dist_mask = dist_mask[None] + dist = np.transpose(dist, (2, 0, 1)) + + mask_and_dist = np.concatenate([dist_mask, dist], axis=0) + return mask_and_dist + + +def get_offsets(): + return np.array(((-1, 0), + (0, -1), + (0., -4.), + (-3., -3.), + (-4., -0.), + (-3., 3.), + (0, -8), + (-8, 0)), int) + + +class AffinityTf(Transform): + """Convert segmentation to stardist""" + + def __init__(self, offsets=None): + super().__init__() + self.offsets = get_offsets() if offsets is None else offsets + + self.seg2aff = Segmentation2AffinitiesWithPadding( + self.offsets, + retain_segmentation=False, + segmentation_to_binary=False, + ignore_label=-1) + + def tensor_function(self, gt): + aff = 1 - self.seg2aff.tensor_function(gt).astype(np.float32) + return np.concatenate(((gt > 0)[None].astype(aff.dtype), aff), axis=0) + + +class CellposeTf(Transform): + """Convert segmentation to cellpose targets + target: ND - array, float + transformed labels in array[nimg x nchan x xy[0] x xy[1]] + target[:, 0] distance fields + target[:, 1:3] flow fields + target[:, 3] boundary fields + target[:, 4] boundary - emphasized weights""" + + def __init__(self): + super().__init__() + + def tensor_function(self, gt): + + ignore_mask = gt < 0 + gt = gt.copy() + gt[ignore_mask] = 0 + + seg, distance_field, probs, flow = masks_to_flows(gt) + # boundary map + bd = (distance_field==1) + bd[bd==0] = 0 + # cell mask + w = 0.1 + (seg>0) + w[ignore_mask] = 0 + + out = np.concatenate([distance_field[None], + flow, + bd[None], + w[None]], axis=0) + out = out.astype(np.float32) + return out + + + # trgt = dynamics.masks_to_flows(train_labels[0], dists=None, use_gpu=False, device=None, omni=True) + # dist = trgt[1] + # heat = trgt[2] + # flow_fields = trgt[3] + + # label, label > 0, utils.distance_to_boundary(label)), + + # mask = lbl[6] = train_labels[0] > 0 + # bg_edt = edt.edt(mask<0.5,black_border=True) #last arg gives weight to the border, which seems to always lose + # cutoff = 9 + # lbl[7] = (gaussian(1-np.clip(bg_edt,0,cutoff)/cutoff, 1)+0.5) + + # return mask_and_dist diff --git a/colocseg/utils.py b/colocseg/utils.py new file mode 100644 index 0000000..980ca27 --- /dev/null +++ b/colocseg/utils.py @@ -0,0 +1,721 @@ +import inspect +import json +import numbers +import os +from time import time +import random +from argparse import ArgumentParser +from functools import partial + +import gunpowder as gp +import imgaug as ia +import matplotlib +import numpy as np +import scipy +import scipy.sparse as sparse +import torch +from gunpowder.batch_request import BatchRequest +from imgaug import augmenters as iaa +from inferno.io.transform.base import Transform +from PIL import Image +from pytorch_lightning.callbacks import Callback +from scipy import ndimage +from skimage import measure +from skimage.filters import rank +from skimage.measure import label, regionprops +from skimage.morphology import disk +from skimage.segmentation import watershed +from skimage.transform import rescale +from sklearn.cluster import DBSCAN +from torch.nn import functional as F +from torch.utils.data import Sampler + + +def offset_slice(offset, reverse=False, extra_dims=0): + def shift(o): + if o == 0: + return slice(None) + elif o > 0: + return slice(o, None) + else: + return slice(0, o) + if not reverse: + return (slice(None),) * extra_dims + tuple(shift(int(o)) for o in offset) + else: + return (slice(None),) * extra_dims + tuple(shift(-int(o)) for o in offset) + + +def label2color(label): + + if isinstance(label, Image.Image): + label = np.array(label) + if len(label.shape) == 3: + label = label[..., 0] + + cmap = matplotlib.cm.get_cmap('nipy_spectral') + shuffle_labels = np.concatenate( + ([0], np.random.permutation(label.max()) + 1)) + label = shuffle_labels[label] + return cmap(label / (label.max() + 1)).transpose(2, 0, 1) + + +def try_remove(filename): + try: + os.remove(filename) + except OSError: + pass + + +def visnorm(x): + x = x - x.min() + x = x / x.max() + return x + + +def vis(x, normalize=True): + if isinstance(x, Image.Image): + x = np.array(x) + + assert(len(x.shape) in [2, 3]) + + if len(x.shape) == 2: + x = x[None] + else: + if x.shape[0] not in [1, 3]: + if x.shape[2] in [1, 3]: + x = x.transpose(2, 0, 1) + else: + raise Exception( + "can not visualize array with shape ", x.shape) + + if normalize: + with torch.no_grad(): + visnorm(x) + + return x + + +def log_img(name, img, pl_module): + pl_module.logger.experiment.add_image(name, img, pl_module.global_step) + + +def is_jsonable(x): + try: + json.dumps(x) + return True + except BaseException: + return False + + +def save_args(args, directory): + os.mkdir(directory) + log_out = os.path.join(directory, "commandline_args.txt") + serializable_args = {key: value for (key, value) in args.__dict__.items() if is_jsonable(value)} + + with open(log_out, 'w') as f: + json.dump(serializable_args, f, indent=2) + + +def adapted_rand(seg, gt, all_stats=False, ignore_label=True): + """Compute Adapted Rand error. + Parameters + ---------- + seg : np.ndarray + the segmentation to score, where each value is the label at that point + gt : np.ndarray, same shape as seg + the groundtruth to score against, where each value is a label + all_stats : boolean, optional + whether to also return precision and recall as a 3-tuple with rand_error + ignore_label: boolean, optional + whether to ignore the zero label + Returns + ------- + are : float + The adapted Rand error; equal to $1 - \frac{2pr}{p + r}$, + where $p$ and $r$ are the precision and recall described below. + prec : float, optional + The adapted Rand precision. (Only returned when `all_stats` is ``True``.) + rec : float, optional + The adapted Rand recall. (Only returned when `all_stats` is ``True``.) + """ + # segA is truth, segB is query + segA = np.ravel(gt) + segB = np.ravel(seg) + n = segA.size + + n_labels_A = int(np.amax(segA)) + 1 + n_labels_B = int(np.amax(segB)) + 1 + + ones_data = np.ones(n) + + p_ij = sparse.csr_matrix( + (ones_data, (segA[:], segB[:])), shape=(n_labels_A, n_labels_B)) + + if ignore_label: + a = p_ij[1:n_labels_A, :] + b = p_ij[1:n_labels_A, 1:n_labels_B] + c = p_ij[1:n_labels_A, 0].todense() + else: + a = p_ij[:n_labels_A, :] + b = p_ij[:n_labels_A, 1:n_labels_B] + c = p_ij[:n_labels_A, 0].todense() + d = b.multiply(b) + + a_i = np.array(a.sum(1)) + b_i = np.array(b.sum(0)) + + sumA = np.sum(a_i * a_i) + sumB = np.sum(b_i * b_i) + (np.sum(c) / n) + sumAB = np.sum(d) + (np.sum(c) / n) + + precision = sumAB / sumB + recall = sumAB / sumA + + fScore = 2.0 * precision * recall / (precision + recall) + are = 1.0 - fScore + + if all_stats: + return {"are": are, + "precision": precision, + "recall": recall} + else: + return are + + +def offset_from_direction(direction, max_direction=8., distance=10): + angle = (direction / max_direction) + angle = 2 * np.pi * angle + + x_offset = int(0.75 * distance * np.sin(angle)) + y_offset = int(0.75 * distance * np.cos(angle)) + + x_offset += random.randint(-int(0.15 * distance), + +int(0.15 * distance)) + y_offset += random.randint(-int(0.15 * distance), + +int(0.15 * distance)) + + return x_offset, y_offset + + +def random_offset(distance=10): + angle = 2 * np.pi * np.random.uniform() + distance = np.random.uniform(low=1., high=distance) + + x_offset = int(distance * np.sin(angle)) + y_offset = int(distance * np.cos(angle)) + + return x_offset, y_offset + +# if y_hat.requires_grad: +# def log_hook(grad_input): +# # torch.cat((grad_input.detach().cpu(), y_hat.detach().cpu()), dim=0) +# grad_input_batch = torch.cat(tuple(torch.cat(tuple(vis(e_0[c]) for c in range(e_0.shape[0])), dim=1) for e_0 in grad_input), dim=2) +# self.logger.experiment.add_image(f'train_regression_grad', grad_input_batch, self.global_step) +# handle.remove() + +# handle = y_hat.register_hook(log_hook) + + +class UpSample(gp.nodes.BatchFilter): + + def __init__(self, source, factor, target): + + assert isinstance(source, gp.ArrayKey) + assert isinstance(target, gp.ArrayKey) + assert ( + isinstance(factor, numbers.Number) or isinstance(factor, tuple)), ( + "Scaling factor should be a number or a tuple of numbers.") + + self.source = source + self.factor = factor + self.target = target + + def setup(self): + + spec = self.spec[self.source].copy() + spec.roi = spec.roi * self.factor + self.provides(self.target, spec) + self.enable_autoskip() + + def prepare(self, request): + + deps = gp.BatchRequest() + sdep = request[self.target] + sdep.roi = sdep.roi / self.factor + deps[self.source] = sdep + return deps + + def process(self, batch, request): + outputs = gp.Batch() + + # logger.debug("upsampeling %s with %s", self.source, self.factor) + + # resize + data = batch.arrays[self.source].data + data = rescale(data, self.factor) + + # create output array + spec = self.spec[self.target].copy() + spec.roi = request[self.target].roi + outputs.arrays[self.target] = gp.Array(data, spec) + + return outputs + + +class AbsolutIntensityAugment(gp.nodes.BatchFilter): + + def __init__(self, array, scale_min, scale_max, shift_min, shift_max): + self.array = array + self.scale_min = scale_min + self.scale_max = scale_max + self.shift_min = shift_min + self.shift_max = shift_max + + def setup(self): + self.enable_autoskip() + self.updates(self.array, self.spec[self.array]) + + def prepare(self, request): + deps = BatchRequest() + deps[self.array] = request[self.array].copy() + return deps + + def process(self, batch, request): + + raw = batch.arrays[self.array] + + raw.data = self.__augment(raw.data, + np.random.uniform(low=self.scale_min, high=self.scale_max), + np.random.uniform(low=self.shift_min, high=self.shift_max)) + + # clip values, we might have pushed them out of [0,1] + raw.data[raw.data > 1] = 1 + raw.data[raw.data < 0] = 0 + + def __augment(self, a, scale, shift): + + return a * scale + shift + + +class Patchify(object): + """ Adapted from + https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/8a4cf8f61644c28d6df54ccffe3a52d6f5fce5a6/pl_bolts/transforms/self_supervised/ssl_transforms.py#L62 + This implementation adds a dilation parameter + """ + + def __init__(self, patch_size, overlap_size, dilation): + self.patch_size = patch_size + self.overlap_size = self.patch_size - overlap_size + self.dilation = dilation + + def patchify_2d(self, x): + x = x.unsqueeze(0) + b, c, h, w = x.size() + + # patch up the images + # (b, c, h, w) -> (b, c*patch_size, L) + x = F.unfold(x, + kernel_size=self.patch_size, + stride=self.overlap_size, + dilation=self.dilation) + + # (b, c*patch_size, L) -> (b, nb_patches, width, height) + x = x.transpose(2, 1).contiguous().view(b, -1, self.patch_size, self.patch_size) + + # reshape to have (b x patches, c, h, w) + x = x.view(-1, c, self.patch_size, self.patch_size) + + x = x.squeeze(0) + + return x + + def __call__(self, x): + if x.dim() == 3: + return self.patchify_2d(x) + else: + raise NotImplementedError("patchify is only implemented for 2d images") + + +class BuildFromArgparse(object): + @classmethod + def from_argparse_args(cls, args, **kwargs): + + if isinstance(args, ArgumentParser): + args = cls.parse_argparser(args) + params = vars(args) + + # we only want to pass in valid DataModule args, the rest may be user specific + valid_kwargs = inspect.signature(cls.__init__).parameters + datamodule_kwargs = dict( + (name, params[name]) for name in valid_kwargs if name in params + ) + datamodule_kwargs.update(**kwargs) + + return cls(**datamodule_kwargs) + + +def quantil_normalize(tensor, pmin=3, pmax=99.8, clip=4., + eps=1e-20, dtype=np.float32, axis=None): + mi = np.percentile(tensor, pmin, axis=axis, keepdims=True) + ma = np.percentile(tensor, pmax, axis=axis, keepdims=True) + + if dtype is not None: + tensor = tensor.astype(dtype, copy=False) + mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype, copy=False) + ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype, copy=False) + eps = dtype(eps) + + try: + import numexpr + x = numexpr.evaluate("(tensor - mi) / ( ma - mi + eps )") + except ImportError: + x = (tensor - mi) / (ma - mi + eps) + + if clip is not None: + x = np.clip(x, -clip, clip) + + return x + + +class QuantileNormalize(Transform): + """Percentile-based image normalization + (adopted from https://github.com/CSBDeep/CSBDeep/blob/master/csbdeep/utils/utils.py)""" + + def __init__(self, pmin=0.6, pmax=99.8, clip=4., + eps=1e-20, dtype=np.float32, + axis=None, **super_kwargs): + """ + Parameters + ---------- + pmin: float + minimum percentile value. The pmin percentile value of the input tensor + is mapped to 0. + pmax: float + maximum percentile value. The pmax percentile value of the input tensor + is mapped to 1. + clip: bool + Clip all values outside of the percentile range to (0, 1) + axis: int, tuple or None + spatial dimensions considerered for the normalization + super_kwargs : dict + Kwargs to the superclass `inferno.io.transform.base.Transform`. + """ + super().__init__(**super_kwargs) + self.pmin = pmin + self.pmax = pmax + self.clip = clip + self.axis = axis + self.dtype = dtype + self.eps = eps + + def tensor_function(self, tensor): + return quantil_normalize(tensor, pmin=self.pmin, pmax=self.pmax, clip=self.clip, + axis=self.axis, dtype=self.dtype, eps=self.eps) + + +class QuantileNormalizeTorchTransform(object): + """Crop randomly the image in a sample. + + Args: + output_size (tuple or int): Desired output size. If int, square crop + is made. + """ + + def __init__(self, pmin=3, pmax=99.8, clip=4., + eps=1e-20, axis=None): + self.pmin = pmin + self.pmax = pmax + self.clip = clip + self.axis = axis + self.eps = eps + + def __call__(self, sample): + return quantil_normalize(sample, pmin=self.pmin, pmax=self.pmax, clip=self.clip, + axis=self.axis, dtype=None, eps=self.eps).float() + + +def pre_channel(img, fun): + if len(img.shape) == 3: + return np.stack(tuple(fun(_) for _ in img), axis=0) + else: + return fun(img) + + +class Scale(Transform): + """ Rescale patch of by constant factor""" + + def __init__(self, scale, **super_kwargs): + super().__init__(**super_kwargs) + self.scale = scale + + def batch_function(self, inp): + + image, segmentation = inp + + if self.scale != 1.: + image = pre_channel( + image, + partial(rescale, + scale=self.scale, + order=3, + anti_aliasing=True)) + + segmentation = pre_channel( + segmentation, + partial(rescale, + scale=self.scale, + order=0)) + + return image.astype(np.float32), segmentation.astype(np.float32) + + +def import_by_string(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +class SaveModelOnValidation(Callback): + + def __init__(self, run_segmentation=False, device='cpu'): + self.run_segmentation = run_segmentation + self.device = device + super().__init__() + + def on_validation_epoch_start(self, trainer, pl_module): + """Called when the validation loop begins.""" + model_directory = os.path.abspath(os.path.join(pl_module.logger.log_dir, + os.pardir, + os.pardir, + "models")) + os.makedirs(model_directory, exist_ok=True) + if hasattr(pl_module, "mini_unet"): + model_save_path = os.path.join( + model_directory, f"mini_unet_{pl_module.global_step:08d}_{pl_module.local_rank:02}.torch") + torch.save({"model_state_dict": pl_module.mini_unet.state_dict()}, model_save_path) + + if hasattr(pl_module, "maxi_unet"): + model_save_path = os.path.join( + model_directory, f"maxi_unet_{pl_module.global_step:08d}_{pl_module.local_rank:02}.torch") + torch.save({"model_state_dict": pl_module.maxi_unet.state_dict()}, model_save_path) + + +# Sometimes(0.5, ...) applies the given augmenter in 50% of all cases, +# e.g. Sometimes(0.5, GaussianBlur(0.3)) would blur roughly every second image. +def sometimes(aug): return iaa.Sometimes(0.5, aug) + + +def get_augmentation_transform(simple=False, medium=False, seed=True): + if simple: + return iaa.Sequential([ + # apply the following augmenters to most images + iaa.Fliplr(0.5), # horizontally flip 50% of all images + iaa.Flipud(0.5), # vertically flip 50% of all images + sometimes(iaa.geometric.Rot90(k=ia.ALL, keep_size=False)) + ]) + if medium: + return iaa.Sequential([ + # apply the following augmenters to most images + iaa.Fliplr(0.5), # horizontally flip 50% of all images + iaa.Flipud(0.5), # vertically flip 50% of all images + sometimes(iaa.geometric.Rot90(k=ia.ALL, keep_size=False)), + sometimes(iaa.ElasticTransformation(alpha=(0., 10.), sigma=(3, 8))), + iaa.SomeOf((0, 2), + [ + iaa.GaussianBlur((0, 1.0)), # blur images with a sigma between 0 and 1.0 + iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 1.), + per_channel=False), # add gaussian noise to images + iaa.Multiply((0.8, 1.2), per_channel=False), + iaa.LinearContrast((0.5, 2.0), per_channel=True), # improve or worsen the contrast + ], random_order=True) + ]) + return iaa.Sequential([ + # apply the following augmenters to most images + iaa.Fliplr(0.5), # horizontally flip 50% of all images + iaa.Flipud(0.5), # vertically flip 50% of all images + sometimes(iaa.ElasticTransformation(alpha=(0., 30.), sigma=(3, 8))), + iaa.SomeOf((0, 2), + [ + iaa.geometric.ScaleX(scale=(1., 1.2), order=1, cval=0, mode='constant'), + iaa.geometric.ScaleY(scale=(1., 1.2), order=1, cval=0, mode='constant'), + iaa.GaussianBlur((0, 0.5)), # blur images with a sigma between 0 and 3.0 + iaa.Dropout((0.01, 0.1), per_channel=False), # randomly remove up to 10% of the pixels + iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 1.), + per_channel=False), # add gaussian noise to images + iaa.Multiply((0.8, 1.2), per_channel=False), + iaa.LinearContrast((0.5, 2.0), per_channel=True), # improve or worsen the contrast + ], random_order=True) + ]) + + +def cluster_embeddings(embeddings, eps=1, min_samples=5): + b, c, h, w = embeddings.shape + emb = embeddings.permute(0, 2, 3, 1) + emb = emb.view(b, -1, c) + clusters = cluster_embeddings_flat(emb, eps=eps, min_samples=min_samples) + clusters = [c.reshape(h, w) for c in clusters] + clusters = np.stack(clusters, axis=0) + return clusters + + +def cluster_embeddings_flat(embeddings, eps=1, min_samples=5): + + batch_of_clusters = [] + # we assume input embeddings are in the form (b, p, c) + start_label = 0 + for emb in embeddings: + # emb.shape = (p, c) + emb = emb.detach().cpu().numpy() + clusters = DBSCAN(eps=eps, + min_samples=min_samples).fit_predict(emb) + + # check that clusters are consecutive + a = np.unique(clusters) + assert((a < 0).all() or a[a >= 0].max() + 1 == len(a[a >= 0])) + + # offset labels by previous maximum label + clusters[clusters >= 0] += start_label + # check consistency + if len(batch_of_clusters): + a = np.unique(batch_of_clusters[-1]) + b = np.unique(clusters) + assert(not np.isin(a[a >= 0], b[b >= 0]).any()) + + batch_of_clusters.append(clusters) + + start_label = clusters.max() + 1 + + return batch_of_clusters + + +def remove_border(raw_0, raw_1, seg, min_size=4000): + seg = seg.copy() + for b in range(raw_0.shape[0]): + local_max_0 = rank.maximum(raw_0[b], disk(1)) + local_min_0 = rank.minimum(raw_0[b], disk(1)) + local_max_1 = rank.maximum(raw_1[b], disk(1)) + local_min_1 = rank.minimum(raw_1[b], disk(1)) + m0 = local_max_0 == local_min_0 + m1 = local_max_1 == local_min_1 + mask = np.logical_and(m0, m1) + mask_seg = label(mask) + reg = regionprops(mask_seg) + for props in reg: + if props.area > min_size: + seg[b, mask_seg == props.label] = 0 + return seg + + +def sizefilter(segmentation, min_size, filter_non_connected=True): + + if min_size == 0: + return segmentation + + if filter_non_connected: + filter_labels = measure.label(segmentation, background=0) + else: + filter_labels = segmentation + ids, sizes = np.unique(filter_labels, return_counts=True) + filter_ids = ids[sizes < min_size] + mask = np.in1d(filter_labels, filter_ids).reshape(filter_labels.shape) + segmentation[mask] = 0 + + return segmentation + + +def sizefilter_batch(segmentation, min_size): + + if min_size == 0: + return segmentation + full_mask = [] + for b in range(segmentation.shape[0]): + ids, sizes = np.unique(segmentation[b], return_counts=True) + filter_ids = ids[sizes < min_size] + full_mask.append(np.in1d(segmentation[b], filter_ids).reshape(segmentation[b].shape)) + full_mask = np.stack(full_mask, axis=0) + segmentation = segmentation.copy() + segmentation[full_mask] = 0 + return segmentation + + +def smooth_boundary_fn(segmentation): + segmentation = segmentation.copy() + initialfg = segmentation > 0 + mask = segmentation == 0 + m1 = segmentation[..., :-1] != segmentation[..., 1:] + mask[..., :-1] += m1 + mask[..., 1:] += m1 + + m2 = segmentation[..., :-1, :] != segmentation[..., 1:, :] + mask[..., :-1, :] += m2 + mask[..., 1:, :] += m2 + segmentation[mask] = 0 + + seeds = measure.label(segmentation, background=0) + mask = segmentation > 1 + distance = ndimage.distance_transform_edt(mask) + 0.1 * np.random.rand(*segmentation.shape) + segmentation[:] = watershed(-distance, seeds, mask=initialfg) + + return segmentation + + +def zarr_append(key, data, outzarr, attr=None): + if key not in outzarr: + outzarr.create_dataset(key, data=data, chunks=data.shape, compression="gzip") + if attr is not None: + outzarr[key].attrs[attr[0]] = attr[1] + else: + outzarr[key].append(data) + + +def zarr_insert(key, data, outzarr, attr=None): + if key not in outzarr: + outzarr.create_dataset(key, data=data, chunks=data.shape, compression="gzip") + if attr is not None: + outzarr[key].attrs[attr[0]] = attr[1] + else: + outzarr[key].append(data) + + # if self.global_step % 100 == 0: + # def log_hook_full(grad_input): + # outzarr = zarr.open(f"grad_{self.global_step}.zarr", "w") + # zarr_append("grad", grad_input.detach().cpu().numpy()[None], outzarr) + # zarr_append("x_aux", (x_aux).detach().cpu().numpy()[None], outzarr) + # zarr_append("target_aux", (target_aux).detach().cpu().numpy()[None], outzarr) + # zarr_append("network_prediction_aux", network_prediction_aux.detach().cpu().numpy()[None], outzarr) + # handle.remove() + # handle = network_prediction_aux.register_hook(log_hook_full) + +def read_config_from_script(script_file, parser): + parser + +class CropAndSkipIgnore(): + def __init__(self, crop_fn, valid_crop=46): + self.crop_fn = crop_fn + self.valid_crop = valid_crop + self.max_tries = 100 + + def acceptable(self, seg): + vc = self.valid_crop + # check if crop contains a valid segment in the center + return seg[:, vc:-vc, vc:-vc].max() > 0 + + def __call__(self, image=None, + segmentation_maps=None): + + raw, gtseg = self.crop_fn(image=image, + segmentation_maps=segmentation_maps) + vc = self.valid_crop + if gtseg[:, vc:-vc, vc:-vc].max() <= 0: + print("no crop available", np.unique(gtseg)) + + for _ in range(self.max_tries): + if not self.acceptable(gtseg): + raw, gtseg = self.crop_fn(image=image, + segmentation_maps=segmentation_maps) + else: + break + + return raw, gtseg diff --git a/colocseg/visualizations.py b/colocseg/visualizations.py new file mode 100644 index 0000000..d640b24 --- /dev/null +++ b/colocseg/visualizations.py @@ -0,0 +1,65 @@ +from skimage.io import imsave +import io + +import matplotlib +matplotlib.use('agg') +import matplotlib.image as mpimg + +import matplotlib.pyplot as plt +import numpy as np + +from colocseg.utils import label2color + +def vis_anchor_embedding(embedding, patch_coords, img, grad=None, output_file=None): + # patch_coords.shape = (num_patches, 2) + + if img is not None: + if img.shape[0] not in [3]: + plt.imshow(img[0], cmap='magma', interpolation='nearest') + else: + plt.imshow(np.transpose(img, (1, 2, 0)), interpolation='nearest') + + if isinstance(embedding, list): + for e in embedding: + plt.quiver(patch_coords[:, 0], + patch_coords[:, 1], + e["embedding"][:, 0], + e["embedding"][:, 1], + angles='xy', + scale_units='xy', + scale=1., color=e["color"]) + else: + plt.quiver(patch_coords[:, 0], + patch_coords[:, 1], + embedding[:, 0], + embedding[:, 1], + angles='xy', + scale_units='xy', + scale=1., color='#8fffdd') + + if grad is not None: + plt.quiver(patch_coords[:, 0], + patch_coords[:, 1], + (10 * grad[:, 0]) / (grad[:, :2].max() + 1e-9), + (10 * grad[:, 1]) / (grad[:, :2].max() + 1e-9), + angles='xy', + scale_units='xy', + scale=1., + color='#ff8fa0') + + plt.axis('off') + + if output_file is not None: + if isinstance(output_file, (list, tuple)): + for of in output_file: + plt.savefig(of, dpi=300, bbox_inches='tight') + else: + plt.savefig(output_file, dpi=300, bbox_inches='tight') + + # buf = io.BytesIO() + # plt.savefig(buf, format='png') + # buf.seek(0) + plt.cla() + plt.clf() + plt.close() + # return buf diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..01219f6 --- /dev/null +++ b/environment.yml @@ -0,0 +1,335 @@ +name: cellulus +channels: + - anaconda + - pytorch + - nvidia + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=4.5=1_gnu + - absl-py=0.13.0=pyhd8ed1ab_0 + - affogato=0.3.1=py38h8289551_1 + - aiohttp=3.7.4.post0=py38h497a2fe_0 + - antlr-python-runtime=4.9.2=py38h578d9bd_0 + - anyio=3.3.0=py38h578d9bd_0 + - argh=0.26.2=pyh9f0ad1d_1002 + - argon2-cffi=20.1.0=py38h497a2fe_2 + - asciitree=0.3.3=py_2 + - async-timeout=3.0.1=py_1000 + - async_generator=1.10=py_0 + - attrs=21.2.0=pyhd8ed1ab_0 + - babel=2.9.1=pyh44b312d_0 + - backcall=0.2.0=pyh9f0ad1d_0 + - backports=1.0=py_2 + - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 + - beautifulsoup4=4.9.3=pyha847dfd_0 + - blas=1.0=mkl + - bleach=4.1.0=pyhd8ed1ab_0 + - blinker=1.4=py_1 + - blosc=1.21.0=h8c45485_0 + - boost=1.70.0=py38h9de70de_1 + - boost-cpp=1.70.0=ha2d47e9_1 + - brotli=1.0.9=he6710b0_2 + - brotlipy=0.7.0=py38h497a2fe_1001 + - brunsli=0.1=h2531618_0 + - bzip2=1.0.8=h7b6447c_0 + - c-ares=1.17.1=h7f98852_1 + - ca-certificates=2021.10.8=ha878542_0 + - cached-property=1.5.2=hd8ed1ab_1 + - cached_property=1.5.2=pyha770c72_1 + - cachetools=4.2.2=pyhd8ed1ab_0 + - certifi=2021.10.8=py38h578d9bd_0 + - cffi=1.14.5=py38ha65f79e_0 + - cfitsio=3.470=hf0d0db6_6 + - chardet=4.0.0=py38h578d9bd_1 + - charls=2.2.0=h2531618_0 + - click=8.0.1=py38h578d9bd_0 + - cloudpickle=1.6.0=py_0 + - conda=4.10.1=py38h06a4308_1 + - conda-build=3.21.4=py38h06a4308_0 + - conda-package-handling=1.7.3=py38h27cfd23_1 + - configargparse=1.4=pyhd3eb1b0_0 + - configparser=5.0.2=pyhd8ed1ab_0 + - cryptography=3.4.7=py38ha5dfef3_0 + - cudatoolkit=11.1.74=h6bb024c_0 + - cycler=0.10.0=py38_0 + - cython=0.29.24=py38h709712a_0 + - cytoolz=0.11.0=py38h7b6447c_0 + - dask-core=2021.6.2=pyhd3eb1b0_0 + - dataclasses=0.8=pyhc8e2a94_1 + - dbus=1.13.18=hb2f20db_0 + - debugpy=1.4.1=py38h709712a_0 + - decorator=4.4.2=py_0 + - defusedxml=0.7.1=pyhd8ed1ab_0 + - dill=0.3.4=pyhd8ed1ab_0 + - docker-pycreds=0.4.0=py_0 + - einops=0.3.0=py_0 + - entrypoints=0.3=pyhd8ed1ab_1003 + - expat=2.4.1=h2531618_2 + - fasteners=0.16.3=pyhd3eb1b0_0 + - ffmpeg=4.3=hf484d3e_0 + - filelock=3.0.12=pyhd3eb1b0_1 + - fontconfig=2.13.1=h6c09931_0 + - freetype=2.10.4=h5ab3b9f_0 + - fsspec=2021.6.1=pyhd8ed1ab_0 + - future=0.18.2=py38h578d9bd_3 + - giflib=5.1.4=h14c3975_1 + - gitdb=4.0.7=pyhd8ed1ab_0 + - gitpython=3.1.18=pyhd8ed1ab_0 + - glib=2.68.2=h36276a3_0 + - glob2=0.7=pyhd3eb1b0_0 + - gmp=6.2.1=h2531618_2 + - gnutls=3.6.15=he1e5248_0 + - google-auth=1.30.2=pyh6c4a22f_0 + - google-auth-oauthlib=0.4.1=py_2 + - gql=0.1.0=py_0 + - graphql-core=3.1.5=pyhd8ed1ab_0 + - grpcio=1.38.1=py38hdd6454d_0 + - gst-plugins-base=1.14.0=h8213a91_2 + - gstreamer=1.14.0=h28cd5cc_2 + - hdf5=1.10.6=nompi_h7c3c948_1111 + - icu=58.2=he6710b0_3 + - idna=2.10=pyh9f0ad1d_0 + - ignite=0.4.6=py_0 + - imagecodecs=2021.6.8=py38h581e88b_0 + - imageio=2.9.0=pyhd3eb1b0_0 + - importlib-metadata=4.5.0=py38h578d9bd_0 + - intel-openmp=2021.2.0=h06a4308_610 + - ipykernel=6.3.1=py38he5a9106_0 + - ipython=7.27.0=py38he5a9106_0 + - ipython_genutils=0.2.0=py_1 + - ipywidgets=7.6.5=pyhd8ed1ab_0 + - jedi=0.18.0=py38h578d9bd_2 + - jinja2=2.11.3=pyhd3eb1b0_0 + - joblib=1.0.1=pyhd3eb1b0_0 + - jpeg=9b=h024ee3a_2 + - json5=0.9.5=pyh9f0ad1d_0 + - jsonschema=3.2.0=pyhd8ed1ab_3 + - jupyter=1.0.0=py38h578d9bd_6 + - jupyter_client=7.0.2=pyhd8ed1ab_0 + - jupyter_console=6.4.0=pyhd8ed1ab_0 + - jupyter_core=4.7.1=py38h578d9bd_0 + - jupyter_server=1.10.2=pyhd8ed1ab_0 + - jupyterlab=3.1.9=pyhd8ed1ab_0 + - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0 + - jupyterlab_server=2.7.2=pyhd8ed1ab_0 + - jupyterlab_widgets=1.0.2=pyhd8ed1ab_0 + - jxrlib=1.1=h7b6447c_2 + - kiwisolver=1.3.1=py38h2531618_0 + - krb5=1.19.1=hcc1bbae_0 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.35.1=h7274673_9 + - lerc=2.2.1=h2531618_0 + - libaec=1.0.4=he6710b0_1 + - libarchive=3.4.2=h62408e4_0 + - libcurl=7.77.0=h2574ce0_0 + - libdeflate=1.7=h27cfd23_5 + - libedit=3.1.20191231=he28a2e2_2 + - libev=4.33=h516909a_1 + - libffi=3.3=he6710b0_2 + - libgcc-ng=9.3.0=h5101ec6_17 + - libgfortran-ng=7.5.0=ha8ba4b0_17 + - libgfortran4=7.5.0=ha8ba4b0_17 + - libgomp=9.3.0=h5101ec6_17 + - libiconv=1.15=h63c8f33_5 + - libidn2=2.3.1=h27cfd23_0 + - liblief=0.10.1=he6710b0_0 + - libnghttp2=1.43.0=h812cca2_0 + - libpng=1.6.37=hbc83047_0 + - libprotobuf=3.17.2=h780b84a_0 + - libsodium=1.0.18=h36c2ea0_1 + - libssh2=1.9.0=ha56f1ee_6 + - libstdcxx-ng=9.3.0=hd4cf53a_17 + - libtasn1=4.16.0=h27cfd23_0 + - libtiff=4.1.0=h2733197_1 + - libunistring=0.9.10=h27cfd23_0 + - libuuid=1.0.3=h1bed415_2 + - libuv=1.40.0=h7b6447c_0 + - libwebp=1.0.1=h8e7db2f_0 + - libxcb=1.14=h7b6447c_0 + - libxml2=2.9.12=h03d6c58_0 + - libzopfli=1.0.3=he6710b0_0 + - locket=0.2.1=py38h06a4308_1 + - lz4-c=1.9.3=h2531618_0 + - markdown=3.3.4=pyhd8ed1ab_0 + - markupsafe=2.0.1=py38h27cfd23_0 + - matplotlib-inline=0.1.2=pyhd8ed1ab_2 + - mistune=0.8.4=py38h497a2fe_1004 + - mkl=2021.2.0=h06a4308_296 + - mkl-service=2.3.0=py38h27cfd23_1 + - mkl_fft=1.3.0=py38h42c9631_2 + - mkl_random=1.2.1=py38ha9443f7_2 + - msgpack-python=1.0.2=py38hff7bd54_1 + - multidict=5.1.0=py38h497a2fe_1 + - nbclassic=0.3.1=pyhd8ed1ab_1 + - nbclient=0.5.4=pyhd8ed1ab_0 + - nbconvert=6.1.0=py38h578d9bd_0 + - nbformat=5.1.3=pyhd8ed1ab_0 + - ncdu=1.16=h46c0cb4_0 + - ncurses=6.2=he6710b0_1 + - nest-asyncio=1.5.1=pyhd8ed1ab_0 + - nettle=3.7.3=hbbd107a_1 + - networkx=2.5.1=pyhd8ed1ab_0 + - ninja=1.10.2=hff7bd54_1 + - notebook=6.4.3=pyha770c72_0 + - numcodecs=0.8.0=py38h2531618_0 + - numpy=1.20.2=py38h2d18471_0 + - numpy-base=1.20.2=py38hfae3a4d_0 + - nvidia-ml=7.352.0=py_0 + - oauthlib=3.1.1=pyhd8ed1ab_0 + - olefile=0.46=py_0 + - omegaconf=2.1.1=py38h578d9bd_0 + - openh264=2.1.0=hd408876_0 + - openjpeg=2.3.0=h05c96fa_1 + - openssl=1.1.1k=h7f98852_0 + - packaging=20.9=pyh44b312d_0 + - pandoc=2.14.2=h7f98852_0 + - pandocfilters=1.4.2=py_1 + - parso=0.8.2=pyhd8ed1ab_0 + - partd=1.2.0=pyhd3eb1b0_0 + - patchelf=0.12=h2531618_1 + - pathtools=0.1.2=py_1 + - pcre=8.45=h295c915_0 + - pexpect=4.8.0=pyh9f0ad1d_2 + - pickleshare=0.7.5=py_1003 + - pillow=8.2.0=py38he98fc37_0 + - pip=21.1.2=py38h06a4308_0 + - pkginfo=1.7.0=py38h06a4308_0 + - portalocker=2.3.2=py38h578d9bd_0 + - prometheus_client=0.11.0=pyhd8ed1ab_0 + - promise=2.3=py38h578d9bd_3 + - prompt-toolkit=3.0.20=pyha770c72_0 + - prompt_toolkit=3.0.20=hd8ed1ab_0 + - protobuf=3.17.2=py38h709712a_0 + - psutil=5.8.0=py38h27cfd23_1 + - ptyprocess=0.7.0=pyhd3deb0d_0 + - py-lief=0.10.1=py38h403a769_0 + - pyasn1=0.4.8=py_0 + - pybind11=2.6.2=py38h1fd1430_0 + - pybind11-global=2.6.2=py38h1fd1430_0 + - pycocotools=2.0.2=py38hb5d20a5_2 + - pycosat=0.6.3=py38h7b6447c_1 + - pycparser=2.20=pyh9f0ad1d_2 + - pygments=2.10.0=pyhd8ed1ab_0 + - pyjwt=2.1.0=pyhd8ed1ab_0 + - pymongo=3.11.0=py38he6710b0_0 + - pyopenssl=20.0.1=pyhd8ed1ab_0 + - pyparsing=2.4.7=pyh9f0ad1d_0 + - pyqt=5.9.2=py38h05f1152_4 + - pyrsistent=0.17.3=py38h497a2fe_2 + - pysocks=1.7.1=py38h578d9bd_3 + - python=3.8.10=h12debd9_8 + - python-dateutil=2.8.1=pyhd3eb1b0_0 + - python-libarchive-c=2.9=pyhd3eb1b0_1 + - python_abi=3.8=2_cp38 + - pytorch=1.9.0=py3.8_cuda11.1_cudnn8.0.5_0 + - pytz=2021.1=pyhd3eb1b0_0 + - pyu2f=0.1.5=pyhd8ed1ab_0 + - pywavelets=1.1.1=py38h7b6447c_2 + - pyyaml=5.4.1=py38h497a2fe_0 + - pyzmq=19.0.2=py38ha71036d_2 + - qt=5.9.7=h5867ecd_1 + - qtconsole=5.1.1=pyhd8ed1ab_0 + - qtpy=1.11.2=pyhd8ed1ab_0 + - readline=8.1=h27cfd23_0 + - requests=2.25.1=pyhd3deb0d_0 + - requests-oauthlib=1.3.0=pyh9f0ad1d_0 + - requests-unixsocket=0.2.0=py_0 + - ripgrep=12.1.1=0 + - rsa=4.7.2=pyh44b312d_0 + - ruamel_yaml=0.15.100=py38h27cfd23_0 + - scikit-image=0.18.1=py38ha9443f7_0 + - scikit-learn=0.24.2=py38ha9443f7_0 + - scipy=1.6.2=py38had2a1c9_1 + - seaborn=0.11.1=pyhd3eb1b0_0 + - send2trash=1.8.0=pyhd8ed1ab_0 + - sentry-sdk=1.1.0=pyhd8ed1ab_0 + - setuptools=52.0.0=py38h06a4308_0 + - shortuuid=1.0.1=py38h578d9bd_4 + - sip=4.19.13=py38he6710b0_0 + - six=1.16.0=pyhd3eb1b0_0 + - smmap=3.0.5=pyh44b312d_0 + - snappy=1.1.8=he6710b0_0 + - sniffio=1.2.0=py38h578d9bd_1 + - soupsieve=2.2.1=pyhd3eb1b0_0 + - sqlite=3.35.4=hdfb4753_0 + - subprocess32=3.5.4=py_1 + - tabulate=0.8.9=pyhd8ed1ab_0 + - tensorboard=2.4.1=pyhd8ed1ab_0 + - tensorboard-plugin-wit=1.8.0=pyh44b312d_0 + - termcolor=1.1.0=py_2 + - terminado=0.11.1=py38h578d9bd_0 + - testpath=0.5.0=pyhd8ed1ab_0 + - threadpoolctl=2.1.0=pyh5ca1d4c_0 + - tifffile=2020.10.1=py38hdd07704_2 + - tk=8.6.10=hbc83047_0 + - toolz=0.11.1=pyhd3eb1b0_0 + - torchaudio=0.9.0=py38 + - torchvision=0.10.0=py38_cu111 + - tornado=6.1=py38h27cfd23_0 + - tqdm=4.61.1=pyhd8ed1ab_0 + - traitlets=5.1.0=pyhd8ed1ab_0 + - typing-extensions=3.10.0.0=hd8ed1ab_0 + - typing_extensions=3.10.0.0=pyha770c72_0 + - urllib3=1.26.5=pyhd8ed1ab_0 + - wandb=0.10.32=pyhd8ed1ab_0 + - watchdog=0.10.4=py38h578d9bd_0 + - wcwidth=0.2.5=pyh9f0ad1d_2 + - webencodings=0.5.1=py_1 + - websocket-client=0.57.0=py38h578d9bd_4 + - werkzeug=2.0.1=pyhd8ed1ab_0 + - widgetsnbextension=3.5.1=py38h578d9bd_4 + - xtensor=0.23.10=h4bd325d_0 + - xtensor-python=0.25.3=py38hfc89cab_0 + - xtl=0.7.2=h4bd325d_1 + - xz=5.2.5=h7b6447c_0 + - yaml=0.2.5=h516909a_0 + - yarl=1.6.3=py38h497a2fe_1 + - zarr=2.8.1=pyhd3eb1b0_0 + - zeromq=4.3.4=h9c3ff4c_0 + - zfp=0.5.5=h2531618_6 + - zipp=3.4.1=pyhd8ed1ab_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.9=haebb681_0 + - pip: + - cellpose==0.7.4.dev1+g45f1a3c + - csbdeep==0.6.2 + - deepcell-toolbox==0.10.2 + - edt==2.1.1 + - fastremap==1.12.2 + - flow-vis==0.1 + - fonttools==4.29.1 + - fvcore==0.1.5.post20210915 + - h5py==2.10.0 + - imgaug==0.4.0 + - iopath==0.1.9 + - ipympl==0.8.0 + - keras==2.3.1 + - keras-applications==1.0.8 + - keras-preprocessing==1.1.2 + - llvmlite==0.36.0 + - matplotlib==3.5.1 + - mpl-interactions==0.20.1 + - natsort==8.0.0 + - numba==0.53.1 + - opencv-python==4.5.3.56 + - opencv-python-headless==4.5.3.56 + - pandas==1.2.5 + - pep8==1.7.1 + - pyasn1-modules==0.2.8 + - pycodestyle==2.8.0 + - pydeprecate==0.3.1 + - pydot==1.4.2 + - pytorch-lightning==1.5.0rc0 + - pytorch-ranger==0.1.1 + - shapely==1.7.1 + - stardist==0.7.1 + - test-tube==0.7.5 + - tiktorch==20.9.4 + - toml==0.10.2 + - torch-optimizer==0.3.0 + - torchmetrics==0.5.1 + - wheel==0.37.0 + - yacs==0.1.8 diff --git a/scripts/convert_tissuenet.py b/scripts/convert_tissuenet.py new file mode 100644 index 0000000..080c24b --- /dev/null +++ b/scripts/convert_tissuenet.py @@ -0,0 +1,20 @@ +import numpy as np +import zarr +import sys +import os + +tissuenet_folder = sys.argv[1] + +output_file = os.path.join(tissuenet_folder, "tissuenet_v1.0.zarr") +zout = zarr.open(output_file, "w") +splits = {"train": "tissuenet_v1.0_train.npz", + "val": "tissuenet_v1.0_val.npz", + "test": "tissuenet_v1.0_test.npz"} + +for split, fn in splits.items(): + with np.load(os.path.join(tissuenet_folder, fn)) as data: + raw_data = data['X'] + gt_data = data['y'] + w, h = raw_data.shape[1:3] + zout.create_dataset(f"{split}/raw", data=raw_data, chunks=(1, w, h, 1), compression='gzip') + zout.create_dataset(f"{split}/gt", data=gt_data, chunks=(1, w, h, 1), compression='gzip', dtype="int32") diff --git a/scripts/postprocess.py b/scripts/postprocess.py new file mode 100644 index 0000000..84538cc --- /dev/null +++ b/scripts/postprocess.py @@ -0,0 +1,25 @@ +from tqdm import trange +import zarr +import sys +from scipy.ndimage.morphology import distance_transform_edt as dtf +import numpy as np + +inpfn = sys.argv[1] +inputkey = sys.argv[2] + +z = zarr.open(inpfn, "r+") +growd = 3 +th = 6 + +inkey = f"{inputkey}" +inlabels = z[inkey] +outkey = f"{inputkey}_postprocessed" + +zout = z.create_dataset(outkey, shape=inlabels.shape, chunks=(16, -1, -1, -1), overwrite=True) +for t in trange(len(inlabels)): + seg = inlabels[t, ..., 0].copy() + dist = dtf(inlabels[t,...,0] == 0) + mskt = dist < growd + dist = dtf(mskt > 0) + seg[dist < th] = 0 + zout[t, ..., 0] = seg diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 0000000..09f038f --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,8 @@ +coverage +codecov>=2.1 +pytest>=3.0.5 +pytest-cov +pytest-flake8 +flake8 +check-manifest +twine==1.13.0 \ No newline at end of file diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 0000000..ad9095d --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,66 @@ +import unittest +import numpy as np + + +class DatamodulesTest(unittest.TestCase): + + def test_initialization(self): + from colocseg.datasets import PatchedDataset + from torch.utils.data.dataset import Dataset + + class NoiseDs(Dataset): + + def __getitem__(self, key): + return np.random.rand(1, 128, 128) + + def __len__(self): + return 10 + + ds = PatchedDataset(NoiseDs(), + output_shape=(100, 100), + positive_radius=15, + density=0.1) + + anchor_samples, refernce_samples = ds.sample_coordinates() + + # make sure all coordinates are in the expected range + assert((refernce_samples >= 0).all()) + assert((refernce_samples < 128).all()) + assert((anchor_samples >= 15).all()) + assert((refernce_samples < 128 - 15).all()) + + def test_supervised_anchors(self): + from colocseg.datasets import SupervisedCoordinateDataset + from torch.utils.data.dataset import Dataset + from skimage.filters import gaussian + from skimage.measure import label + from skimage import data + class CoinDs(Dataset): + + def __getitem__(self, key): + seg = label((gaussian(data.coins(), sigma=3)>0.4)) + return data.coins()[None], seg + + def __len__(self): + return 10 + + ds = SupervisedCoordinateDataset(CoinDs(), + (200, 200), + 2000, + min_size=20, + return_segmentation=True) + + raw, anc, ref, gt = ds[0] + + # uncomment to inspect results visually + # import matplotlib.pyplot as plt + # plt.imshow(gt) + # for i in range(0, len(anc)): + # plt.plot((anc[i, 0], ref[i, 0]), (anc[i, 1], ref[i, 1]), 'ro-') + # plt.show() + + + +if __name__ == '__main__': + # unittest.main() + DatamodulesTest().test_supervised_anchors()