Hi Shahira, I am writing this issue to serve as notes for myself and other readers. I am documenting how to apply this model to images.
First, I converted the CoNSeP model to TorchScript, for use outside of Python and to serialize the model together with its weights.
Code to load the model trained on CoNSeP and save it to TorchScript:

```python
import torch
from model_arch import UnetVggMultihead

model_config = {
    "dropout_prob": 0.2,
    "initial_pad": 126,
    "interpolate": "False",
    "conv_init": "he",
    "n_classes": 3,  # lymphocytes, tumor, stroma
    "n_channels": 3,
    "n_heads": 4,
    "head_classes": [1, 3, 5*3, 21],
}
model = UnetVggMultihead(kwargs=model_config)
weights = torch.load("pretrained_models/mcspat_consep.pth", map_location="cpu")
model.load_state_dict(weights)
# Switch dropout to eval mode before tracing: the trace records the ops run in the current mode.
model.eval()
script = torch.jit.trace(model, torch.ones(1, 3, 448, 448))
# Freeze the model, and do some other optimizations...
script_opt = torch.jit.optimize_for_inference(script)
# Save.
torch.jit.save(script_opt, "mcspatnet-consep-torchscript-opt.pth")
```

mcspatnet-consep-torchscript-opt.pth.gz
If you download the model, please gunzip it -- I gzipped the file so I could upload it here.
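To go from the attached `.gz` file to a model input, something like the following may help with steps 1 to 3 below. This is a minimal sketch and not part of the MCSpatNet code. It assumes the file sits in the working directory, that the image is an (H, W, 3) uint8 RGB array, and that the model wants the same (1, 3, H, W) float layout used for tracing above. `gunzip`, `load_torchscript` and `prepare_image` are hypothetical helper names.

```python
import gzip
import shutil
from pathlib import Path

import numpy as np


def gunzip(src, dst=None):
    """Decompress a .gz file; by default 'x.pth.gz' is written to 'x.pth'."""
    src = Path(src)
    dst = Path(dst) if dst is not None else src.with_suffix("")
    with gzip.open(src, "rb") as f_in, open(dst, "wb") as f_out:
        shutil.copyfileobj(f_in, f_out)
    return dst


def load_torchscript(path):
    """Step 1: load the TorchScript file on the CPU."""
    import torch  # lazy import so the other helpers work without PyTorch

    model = torch.jit.load(str(path), map_location="cpu")
    model.eval()
    return model


def prepare_image(rgb_uint8):
    """Steps 2-3 (assumed layout): (H, W, 3) uint8 RGB -> (1, 3, H, W) float32 in [0, 1]."""
    x = np.asarray(rgb_uint8, dtype=np.float32) / 255.0
    return np.ascontiguousarray(x.transpose(2, 0, 1)[np.newaxis])
```

Usage would then look like `model = load_torchscript(gunzip("mcspatnet-consep-torchscript-opt.pth.gz"))` followed by `outputs = model(torch.from_numpy(prepare_image(rgb)))`, which should give the outputs described in step 4.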
Here are the steps to apply the model to new data...
1. Load the model.
2. Load an RGB image of tissue (500x500 pixels at 0.5 microns per pixel). @ShahiraAbousamra - is this correct?
3. Normalize the image to [0, 1]. For an 8-bit image, this means dividing by 255.
4. Run the image through the model as a float tensor of shape (1, 3, H, W), the same layout as the tracing input above. This gives four outputs: a density map for all nuclei, a density map for nuclei with associated classes, and two more that I'm not quite sure about.
5. Apply a sigmoid to the density map for all nuclei (the first output of the model).
6. Threshold the sigmoided density map with `skimage.filters.apply_hysteresis_threshold(density_map, low=0.5, high=0.5)`. Because `low` equals `high`, this is the same as the plain threshold `density_map > 0.5`.
7. Find connected components in this thresholded density map with `skimage.measure.label(arr)`.
8. Remove any connected components whose area is below a threshold (the code uses 5 pixels).
9. Get the center of each remaining connected component. Each center marks the presence of a nucleus. There is no associated classification yet; that happens in the next few steps.

   ```python
   import cv2
   import numpy as np

   # `density` is the binary mask left after step 8.
   centers = np.zeros((500, 500), dtype=np.uint8)  # image height and width
   # OpenCV 4.x returns (contours, hierarchy).
   contours, hierarchy = cv2.findContours(
       density.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
   )
   for contour_i in contours:
       M = cv2.moments(contour_i)
       if M["m00"] == 0:
           continue  # zero-area contour, so the centroid is undefined
       # The centroid of this prediction.
       cx = round(M["m10"] / M["m00"])
       cy = round(M["m01"] / M["m00"])
       centers[cy, cx] = 1  # mark this nucleus so step 12 can use it
   ```
10. Apply a softmax to the second output of the model (the density map for nuclei with associated classes) along the channel dimension.
11. Take the argmax along the channel dimension to get a hard class at each pixel in the image.
12. For each class:

    a. Create a binary mask indicating the presence of that class: `arr == class_idx`.

    b. Multiply this binary mask by the `centers` array from step 9. This zeroes out every pixel that is not a nucleus center of this class.

    c. Get the coordinates of each remaining point:

    ```python
    import numpy as np

    dots = (arr == class_idx) * centers  # nucleus centers predicted as class_idx
    with_dots = np.where(dots > 0)
    for cy, cx in zip(*with_dots):
        # (cy, cx) is the location of a nucleus of class `class_idx`.
        pass
    ```
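Steps 5 to 9 can also be prototyped with NumPy alone, which may help when checking the pipeline without scikit-image or OpenCV. This is a sketch with stated substitutions, not the original code:

* It uses a plain `> 0.5` threshold, which is equivalent to the hysteresis call when `low == high`.
* It uses a hand-written 8-connected labeling in place of `skimage.measure.label`, whose default 2-D connectivity is also 8.
* It takes each component's mean pixel coordinate instead of `cv2.moments` on its outer contour. Centers can therefore differ from the original by a pixel, and zero-area contours are not skipped.

It assumes the first model output has been squeezed to an (H, W) array, for example `outputs[0][0, 0].numpy()` if that output has shape (1, 1, H, W). `detect_nuclei` and `label_components` are hypothetical names.

```python
from collections import deque

import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def label_components(mask):
    """Label 8-connected True regions of a 2-D boolean mask (stand-in for
    skimage.measure.label). Returns (labels, n) with labels in 1..n."""
    labels = np.zeros(mask.shape, dtype=np.int32)
    h, w = mask.shape
    n = 0
    for y0, x0 in zip(*np.nonzero(mask)):
        if labels[y0, x0]:
            continue  # pixel already belongs to a labeled component
        n += 1
        labels[y0, x0] = n
        queue = deque([(y0, x0)])
        while queue:  # breadth-first flood fill over the 8 neighbors
            y, x = queue.popleft()
            for dy in (-1, 0, 1):
                for dx in (-1, 0, 1):
                    ny, nx = y + dy, x + dx
                    if 0 <= ny < h and 0 <= nx < w and mask[ny, nx] and not labels[ny, nx]:
                        labels[ny, nx] = n
                        queue.append((ny, nx))
    return labels, n


def detect_nuclei(detection_logits, threshold=0.5, min_area=5):
    """Steps 5-9 on an (H, W) array of raw detection outputs.
    Returns an (H, W) uint8 array with 1 at each predicted nucleus center."""
    prob = sigmoid(detection_logits)     # step 5
    mask = prob > threshold              # step 6 (hysteresis with low == high)
    labels, n = label_components(mask)   # step 7
    areas = np.bincount(labels.ravel(), minlength=n + 1)
    centers = np.zeros(mask.shape, dtype=np.uint8)
    for lab in range(1, n + 1):
        if areas[lab] < min_area:        # step 8: drop small components
            continue
        ys, xs = np.nonzero(labels == lab)
        # Step 9: centroid of the component's pixels (not of its contour).
        centers[int(round(ys.mean())), int(round(xs.mean()))] = 1
    return centers
```

The per-label loop is quadratic-ish and slow on large images. It is meant for checking logic on small crops, not as a replacement for `skimage.measure.label`.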
Step 9 gives the location of each nucleus. Step 12c gives the location and class of the nuclei.
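Steps 10 to 12 vectorize naturally in NumPy. Below is a minimal sketch. It assumes the second model output has been converted to a NumPy array of shape (C, H, W) with the batch dimension dropped, and that `centers` is the (H, W) array from step 9. `classify_nuclei` is a hypothetical helper name. Softmax is monotonic along the channel axis, so the argmax of the softmax equals the argmax of the raw outputs. Step 10 therefore only matters if you also want the class probabilities.

```python
import numpy as np


def softmax(x, axis=0):
    z = x - x.max(axis=axis, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def classify_nuclei(class_logits, centers):
    """Steps 10-12. class_logits: (C, H, W) raw second output; centers: (H, W),
    nonzero at nucleus centers. Returns a list of (cy, cx, class_idx) tuples."""
    probs = softmax(class_logits, axis=0)           # step 10: softmax over channels
    arr = probs.argmax(axis=0)                      # step 11: hard class per pixel
    found = []
    for class_idx in range(class_logits.shape[0]):  # step 12
        dots = (arr == class_idx) * centers         # 12a and 12b
        for cy, cx in zip(*np.nonzero(dots)):       # 12c
            found.append((int(cy), int(cx), class_idx))
    return found
```

Results come out grouped by class and, within a class, in row-major order. Sort them if you need a different order.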
Here is an example of the results. The cyan dots show the predicted points.
