Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 59 additions & 27 deletions src/neuron_proofreader/machine_learning/image_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from neuron_proofreader.utils import geometry_util, img_util, util


# --- Image Reading ---
# ----------------------------------------------------------------------------
# Image Class
# ----------------------------------------------------------------------------

class TensorStoreImage:
"""
Class that reads images with the TensorStore library.
Expand Down Expand Up @@ -93,15 +96,18 @@ def shape(self):
return self.img.shape


# --- Patch Loading ---
# ----------------------------------------------------------------------------
# PatchLoader Class
# ----------------------------------------------------------------------------

class PatchLoader(ABC):
"""
A class for reading image patches and generating segment masks.
"""

max_voxel_shift = 5

def __init__(self, graph, img_config, img_path):
def __init__(self, graph, img_config):
"""
Instantiates a PatchLoader object.

Expand All @@ -116,7 +122,7 @@ def __init__(self, graph, img_config, img_path):
"""
self.config = img_config or ImageConfig()
self.graph = graph
self.img = TensorStoreImage(img_path)
self.img = TensorStoreImage(img_config.img_path)
self.transform = ImageTransforms() if self.config.transform else None

# --- Abstract Interface ---
Expand All @@ -134,18 +140,16 @@ def compute_patch_specs(self):
"""
pass

@abstractmethod
def create_mask(self):
"""
Abstract method to be implemented by subclasses.
"""
pass

# --- Core Routines ---
def annotate_foreground(self, mask, nodes, offset, fill_val=1):
visited = set()
for i in nodes:
# Check whether to visit voxel
voxel_i = self.graph.node_local_voxel(i, offset)
if not img_util.is_contained(voxel_i, mask.shape):
continue

# Visit neighbors
for j in self.graph.neighbors(i):
if frozenset({i, j}) not in visited and j in nodes:
voxel_j = self.graph.node_local_voxel(j, offset)
Expand All @@ -163,6 +167,19 @@ def annotate_fragment(self, mask, subgraph, offset, fill_val=1):
voxels = geometry_util.make_digital_line(voxel1, voxel2)
img_util.annotate_voxels(mask, voxels, fill_val=fill_val)

def create_mask(self, center, shape, node):
# Initializations
offset = img_util.get_offset(center, shape)
depth = np.sqrt(2) * np.max(shape) / (2 * self.graph.anisotropy.min())
nodes = self.get_foreground_nodes(node, depth)
subgraph = self.graph.rooted_subgraph(node, depth)

# Annotate mask
mask = np.zeros(shape)
#self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP
self.annotate_fragment(mask, subgraph, offset, fill_val=1)
return mask

def read_image(self, center, shape):
"""
Reads the image patch specified by the given center and shape.
Expand Down Expand Up @@ -195,6 +212,11 @@ def adjust_voxel(self, voxel):
)
return voxel

def get_foreground_nodes(self, node, radius):
xyz = self.graph.node_xyz[node]
nodes = self.graph.kdtree.query_ball_point(xyz, radius)
return nodes

@staticmethod
def stack(img, mask):
try:
Expand All @@ -205,6 +227,10 @@ def stack(img, mask):
return patches


# ----------------------------------------------------------------------------
# PatchLoader Subclasses
# ----------------------------------------------------------------------------

class DetectionPatchLoader(PatchLoader):

# --- Implementation of Abstract Inferface ---
Expand All @@ -225,24 +251,30 @@ def compute_patch_specs(self, node):
voxel = self.adjust_voxel(voxel)
return voxel, self.patch_shape

def create_mask(self, center, shape, node):
# Initializations
offset = img_util.get_offset(center, shape)
depth = np.sqrt(2) * np.max(shape) / (2 * self.graph.anisotropy.min())
nodes = self.get_foreground_nodes(node, depth)
subgraph = self.graph.rooted_subgraph(node, depth)

# Annotate mask
mask = np.zeros(shape)
# self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP
self.annotate_fragment(mask, subgraph, offset, fill_val=1)
return mask
class DetectionBatchLoader(PatchLoader):

# --- Helpers ---
def get_foreground_nodes(self, node, radius):
xyz = self.graph.node_xyz[node]
nodes = self.graph.kdtree.query_ball_point(xyz, radius)
return nodes
def __call__(self, nodes):
# Compute patch info
center, shape = self.compute_patch_specs(nodes)
offset = center - shape // 2

# Load patches
img = self.read_image(center, shape)
mask = self.create_mask(center, shape, nodes[len(nodes) // 2])
return self.stack(img, mask), offset

def compute_patch_specs(self, nodes):
# Compute bounding box
centers = np.array([self.graph.node_voxel(i) for i in nodes])
buffer = np.array(self.patch_shape) // 2
start = centers.min(axis=0) - buffer
end = centers.max(axis=0) + buffer

# Image patch location
shape = (end - start).astype(int)
center = (start + shape // 2).astype(int)
return center, shape


class ProposalPatchLoader(PatchLoader):
Expand Down
Loading
Loading