diff --git a/src/neuron_proofreader/machine_learning/image_dataloader.py b/src/neuron_proofreader/machine_learning/image_dataloader.py index 554be34..790969c 100644 --- a/src/neuron_proofreader/machine_learning/image_dataloader.py +++ b/src/neuron_proofreader/machine_learning/image_dataloader.py @@ -20,7 +20,10 @@ from neuron_proofreader.utils import geometry_util, img_util, util -# --- Image Reading --- +# ---------------------------------------------------------------------------- +# Image Class +# ---------------------------------------------------------------------------- + class TensorStoreImage: """ Class that reads images with the TensorStore library. @@ -93,7 +96,10 @@ def shape(self): return self.img.shape -# --- Patch Loading --- +# ---------------------------------------------------------------------------- +# PatchLoader Class +# ---------------------------------------------------------------------------- + class PatchLoader(ABC): """ A class for reading image patches and generating segment masks. @@ -101,7 +107,7 @@ class PatchLoader(ABC): max_voxel_shift = 5 - def __init__(self, graph, img_config, img_path): + def __init__(self, graph, img_config): """ Instantiates a PatchLoader object. @@ -116,7 +122,7 @@ def __init__(self, graph, img_config, img_path): """ self.config = img_config or ImageConfig() self.graph = graph - self.img = TensorStoreImage(img_path) + self.img = TensorStoreImage(img_config.img_path) self.transform = ImageTransforms() if self.config.transform else None # --- Abstract Interface --- @@ -134,18 +140,16 @@ def compute_patch_specs(self): """ pass - @abstractmethod - def create_mask(self): - """ - Abstract method to be implemented by subclasses. - """ - pass - # --- Core Routines --- def annotate_foreground(self, mask, nodes, offset, fill_val=1): visited = set() for i in nodes: + # Check whether to visit voxel voxel_i = self.graph.node_local_voxel(i, offset) + if not img_util.is_contained(voxel_i, mask.shape): + continue + + # Visit neighbors for j in self.graph.neighbors(i): if frozenset({i, j}) not in visited and j in nodes: voxel_j = self.graph.node_local_voxel(j, offset) @@ -163,6 +167,19 @@ def annotate_fragment(self, mask, subgraph, offset, fill_val=1): voxels = geometry_util.make_digital_line(voxel1, voxel2) img_util.annotate_voxels(mask, voxels, fill_val=fill_val) + def create_mask(self, center, shape, node): + # Initializations + offset = img_util.get_offset(center, shape) + depth = np.sqrt(2) * np.max(shape) / (2 * self.graph.anisotropy.min()) + nodes = self.get_foreground_nodes(node, depth) + subgraph = self.graph.rooted_subgraph(node, depth) + + # Annotate mask + mask = np.zeros(shape) + #self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP + self.annotate_fragment(mask, subgraph, offset, fill_val=1) + return mask + def read_image(self, center, shape): """ Reads the image patch specified by the given center and shape. @@ -195,6 +212,11 @@ def adjust_voxel(self, voxel): ) return voxel + def get_foreground_nodes(self, node, radius): + xyz = self.graph.node_xyz[node] + nodes = self.graph.kdtree.query_ball_point(xyz, radius) + return nodes + @staticmethod def stack(img, mask): try: @@ -205,6 +227,10 @@ def stack(img, mask): return patches +# ---------------------------------------------------------------------------- +# PatchLoader Subclasses +# ---------------------------------------------------------------------------- + class DetectionPatchLoader(PatchLoader): # --- Implementation of Abstract Inferface --- @@ -225,24 +251,30 @@ def compute_patch_specs(self, node): voxel = self.adjust_voxel(voxel) return voxel, self.patch_shape - def create_mask(self, center, shape, node): - # Initializations - offset = img_util.get_offset(center, shape) - depth = np.sqrt(2) * np.max(shape) / (2 * self.graph.anisotropy.min()) - nodes = self.get_foreground_nodes(node, depth) - subgraph = self.graph.rooted_subgraph(node, depth) - # Annotate mask - mask = np.zeros(shape) - # self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP - self.annotate_fragment(mask, subgraph, offset, fill_val=1) - return mask +class DetectionBatchLoader(PatchLoader): - # --- Helpers --- - def get_foreground_nodes(self, node, radius): - xyz = self.graph.node_xyz[node] - nodes = self.graph.kdtree.query_ball_point(xyz, radius) - return nodes + def __call__(self, nodes): + # Compute patch info + center, shape = self.compute_patch_specs(nodes) + offset = center - shape // 2 + + # Load patches + img = self.read_image(center, shape) + mask = self.create_mask(center, shape, nodes[len(nodes) // 2]) + return self.stack(img, mask), offset + + def compute_patch_specs(self, nodes): + # Compute bounding box + centers = np.array([self.graph.node_voxel(i) for i in nodes]) + buffer = np.array(self.patch_shape) // 2 + start = centers.min(axis=0) - buffer + end = centers.max(axis=0) + buffer + + # Image patch location + shape = (end - start).astype(int) + center = (start + shape // 2).astype(int) + return center, shape class ProposalPatchLoader(PatchLoader): diff --git a/src/neuron_proofreader/merge_proofreading/merge_inference.py b/src/neuron_proofreader/merge_proofreading/merge_inference.py index ac1b8ce..a821d99 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_inference.py +++ b/src/neuron_proofreader/merge_proofreading/merge_inference.py @@ -12,7 +12,7 @@ from abc import ABC, abstractmethod from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait from torch.nn.functional import sigmoid -from torch.utils.data import IterableDataset +from torch.utils.data import IterableDataset, DataLoader from time import time from tqdm import tqdm @@ -25,8 +25,11 @@ from neuron_proofreader.machine_learning.point_cloud_models import ( subgraph_to_point_cloud, ) +from neuron_proofreader.machine_learning.image_dataloader import ( + DetectionBatchLoader, + DetectionPatchLoader, +) from neuron_proofreader.utils import ( - geometry_util, img_util, ml_util, swc_util, @@ -41,14 +44,16 @@ def __init__( dataset, model, model_path, + batch_size=16, device="cuda", remove_detected_sites=False, threshold=0.4, ): # Instance attributes + self.batch_size = batch_size self.dataset = dataset self.device = device - self.node_preds = np.ones((len(dataset.node_xyz))) * 1e-2 + self.node_preds = np.zeros((len(dataset.node_xyz))) self.patch_shape = dataset.patch_shape self.remove_detected_sites = remove_detected_sites self.threshold = threshold @@ -61,8 +66,9 @@ def __init__( def search_graph(self): # Iterate over dataset t0 = time() + dataloader = DataLoader(self.dataset, batch_size=self.batch_size) pbar = tqdm(total=self.dataset.estimate_iterations()) - for nodes, x_nodes in self.dataset: + for nodes, x_nodes in dataloader: self.node_preds[np.array(nodes)] = self.predict(x_nodes) pbar.update(len(nodes)) @@ -132,7 +138,7 @@ def filter_with_nms(self, merge_sites, likelihoods): ) if iou > 0.35: merge_sites_set.remove(i) - self.node_preds[i] = 1e-2 + self.node_preds[i] = 0 # Populate queue for j in self.dataset.neighbors(i): @@ -185,8 +191,9 @@ def save_results( ): self.save_sites(output_dir) if save_fragments: + self.dataset.graph.node_radius = 10 * np.maximum(self.node_preds, 0.1) fragments_path = os.path.join(output_dir, "fragments.zip") - self.dataset.to_zipped_swcs(fragments_path) + self.dataset.to_zipped_swcs(fragments_path, use_radius=True) # Upload results to S3 (if applicable) if output_prefix_s3: @@ -195,12 +202,15 @@ def save_results( def save_sites(self, output_dir): # Save model predictions - df = pd.DataFrame(columns=["World", "Segment_ID", "Prediction"]) - df["World"] = self.dataset.node_xyz + df = pd.DataFrame( + columns=["World", "Segment_ID", "Prediction", "Degree"] + ) + df["World"] = list(map(tuple, self.dataset.node_xyz)) df["Prediction"] = self.node_preds df["Segment_ID"] = [ self.dataset.node_segment_id(i) for i in self.dataset.nodes ] + df["Degree"] = [self.dataset.degree[i] for i in self.dataset.nodes] df.to_csv(os.path.join(output_dir, "model_predictions.csv")) # Get predicted merge sites @@ -235,51 +245,35 @@ def save_train_dataset(self, output_dir): print("# Fragments Saved:", len(roots)) -# --- Data Handling --- +# --- Datasets --- class GraphDataset(IterableDataset, ABC): def __init__( self, graph, - img_path, - patch_shape, - batch_size=16, - brightness_clip=400, + img_config, is_multimodal=False, min_search_size=0, prefetch=64, - segmentation_path=None, subgraph_radius=100, - use_new_mask=False, ): # Call parent class super().__init__() # Instance attributes - self.batch_size = batch_size - self.brightness_clip = brightness_clip self.distance_traversed = 0 self.graph = graph self.is_multimodal = is_multimodal self.min_size = min_search_size - self.patch_shape = patch_shape + self.patch_shape = img_config.patch_shape self.prefetch = prefetch - self.segmentation_path = segmentation_path self.subgraph_radius = subgraph_radius - self.use_new_mask = use_new_mask # Batch getter if is_multimodal: - self.get_batch = self._get_multimodal_batch + self.generate_inputs = None else: - self.get_batch = self._get_batch - - # Image reader - self.img_reader = img_util.TensorStoreReader(img_path) - if self.segmentation_path: - self.segmentation_reader = img_util.TensorStoreReader( - segmentation_path - ) + self.generate_inputs = self.generate_patches # --- Core routines --- def __iter__(self): @@ -326,70 +320,12 @@ def find_fragments_to_search(self): component_ids.add(self.node_component_id[node]) return component_ids - def get_patch_centers(self, nodes): - patch_centers = [self.node_voxel(i) for i in nodes] - return np.array(patch_centers, dtype=int) - - def get_label_mask(self, nodes, img_shape, offset): - # Read segmentation - if self.use_new_mask: - center = [o + s // 2 for o, s in zip(offset, img_shape)] - segment_mask = self.segmentation_reader.read(center, img_shape) - segment_mask = img_util.remove_small_segments(segment_mask, 1000) - segment_mask = 0.5 * (segment_mask > 0).astype(int) - else: - segment_mask = np.zeros(img_shape) - - # Annotate mask - subgraph = self.get_contained_subgraph(nodes, img_shape, offset) - for i, j in subgraph.edges: - voxel_i = self.node_voxel(i) - offset - voxel_j = self.node_voxel(j) - offset - voxels = geometry_util.make_digital_line(voxel_i, voxel_j) - img_util.annotate_voxels(segment_mask, voxels) - return segment_mask - - def get_contained_subgraph(self, nodes, img_shape, offset): - queue = list(nodes) - visited = set(nodes) - subgraph = nx.Graph() - while queue: - # Visit node - i = queue.pop() - voxel_i = self.node_voxel(i) - offset - if not img_util.is_contained(voxel_i, img_shape, buffer=1): - continue - - # Update queue - for j in self.neighbors(i): - voxel_j = self.node_voxel(j) - offset - if img_util.is_contained(voxel_j, img_shape): - subgraph.add_edge(i, j) - if j not in visited: - queue.append(j) - visited.add(j) - return subgraph - def is_contained(self, node): voxel = self.node_voxel(node) - shape = self.img_reader.shape()[2::] + shape = self.patch_loader.img.shape()[2::] buffer = np.max(self.patch_shape) + 1 return img_util.is_contained(voxel, shape, buffer=buffer) - def read_superchunk(self, nodes): - # Compute bounding box - patch_centers = self.get_patch_centers(nodes) - buffer = 1 + np.array(self.patch_shape) // 2 - start = patch_centers.min(axis=0) - buffer - end = patch_centers.max(axis=0) + buffer - - # Read image - shape = (end - start).astype(int) - center = (start + shape // 2).astype(int) - superchunk = self.img_reader.read(center, shape) - superchunk = np.minimum(superchunk, self.brightness_clip) - return superchunk, start.astype(int) - def is_near_leaf(self, node, threshold=20): # Check if node is branching if self.degree[node] > 2: @@ -417,40 +353,37 @@ def is_node_valid(self, node): is_nonleaf = not self.is_near_leaf(node) return is_contained and is_nonleaf + # --- Helpers --- + def __getattr__(self, name): + return getattr(self.graph, name) -class DenseGraphDataset(GraphDataset): + +class DenseDataset(GraphDataset): + + max_batch_span = 512 def __init__( self, graph, - img_path, - patch_shape, - batch_size=16, - brightness_clip=300, + img_config, is_multimodal=False, min_search_size=0, - prefetch=128, - segmentation_path=None, - step_size=10, + prefetch=64, + step_size=40, subgraph_radius=100, - use_new_mask=False, ): # Call parent class super().__init__( graph, - img_path, - patch_shape, - batch_size=batch_size, - brightness_clip=brightness_clip, + img_config, is_multimodal=is_multimodal, min_search_size=min_search_size, prefetch=prefetch, - segmentation_path=segmentation_path, subgraph_radius=subgraph_radius, - use_new_mask=use_new_mask, ) # Instance attributes + self.patch_loader = DetectionBatchLoader(self.graph, img_config) self.search_mode = "dense" self.step_size = step_size @@ -459,7 +392,7 @@ def _generate_batches_from_component(self, root): def submit_thread(): try: nodes = next(batch_nodes_generator) - thread = executor.submit(self.read_superchunk, nodes) + thread = executor.submit(self.patch_loader, nodes) pending[thread] = nodes except StopIteration: pass @@ -480,7 +413,7 @@ def submit_thread(): # Process completed thread nodes = pending.pop(thread) img, offset = thread.result() - yield self.get_batch(nodes, img, offset) + yield from self.generate_inputs(nodes, img, offset) # Continue submitting threads submit_thread() @@ -511,9 +444,7 @@ def _generate_batch_nodes(self, root): continue # Check whether to yield batch - is_node_far = self.dist(root, j) > 512 - is_batch_full = len(nodes) == self.batch_size - if is_node_far or is_batch_full: + if self.dist(root, j) > self.max_batch_span: # Yield nodes in batch yield np.array(nodes, dtype=int) @@ -533,24 +464,12 @@ def _generate_batch_nodes(self, root): if nodes: yield np.array(nodes, dtype=int) - def _get_batch(self, nodes, img, offset): - # Initializations - label_mask = self.get_label_mask(nodes, img.shape, offset) - patch_centers = self.get_patch_centers(nodes) - offset - - # Populate batch array - batch = np.empty( - ( - len(nodes), - 2, - ) - + self.patch_shape - ) - for i, center in enumerate(patch_centers): + def generate_patches(self, nodes, img, offset): + voxels = np.array([self.node_voxel(i) for i in nodes], dtype=int) + for node, center in zip(nodes, voxels - offset): s = img_util.get_slices(center, self.patch_shape) - batch[i, 0, ...] = img_util.normalize(img[s]) - batch[i, 1, ...] = label_mask[s] - return nodes, torch.tensor(batch, dtype=torch.float) + patch = torch.from_numpy(img[(slice(0, 2), *s)]).float() + yield node, patch def _get_multimodal_batch(self, nodes, img, offset): # Initializations @@ -558,13 +477,7 @@ def _get_multimodal_batch(self, nodes, img, offset): patch_centers = self.get_patch_centers(nodes) - offset # Populate batch array - patches = np.empty( - ( - len(nodes), - 2, - ) - + self.patch_shape - ) + patches = np.empty((len(nodes), 2,) + self.patch_shape) point_clouds = np.empty((len(nodes), 3, 3600), dtype=np.float32) for i, (node, center) in enumerate(zip(nodes, patch_centers)): s = img_util.get_slices(center, self.patch_shape) @@ -584,9 +497,6 @@ def _get_multimodal_batch(self, nodes, img, offset): return nodes, batch # --- Helpers --- - def __getattr__(self, name): - return getattr(self.graph, name) - def estimate_iterations(self): """ Estimates the number of iterations required to search graph. @@ -611,37 +521,7 @@ def estimate_iterations(self): return int(total_cable_length / self.step_size) -class SparseGraphDataset(GraphDataset): - - def __init__( - self, - graph, - img_path, - patch_shape, - batch_size=16, - is_multimodal=False, - min_search_size=0, - prefetch=128, - segmentation_path=None, - subgraph_radius=100, - use_new_mask=False, - ): - # Call parent class - super().__init__( - graph, - img_path, - patch_shape, - batch_size=batch_size, - is_multimodal=is_multimodal, - min_search_size=min_search_size, - prefetch=prefetch, - segmentation_path=segmentation_path, - subgraph_radius=subgraph_radius, - use_new_mask=use_new_mask, - ) - - # Instance attributes - self.search_mode = "branching_points" +class SparseDataset(GraphDataset): def _generate_batches_from_component(self): pass @@ -658,9 +538,7 @@ def _generate_batch_nodes(self, root): patch_centers.append(self.graph.node_voxel(i)) # Check whether to yield batch - is_node_far = self.graph.dist(root, j) > 256 - is_batch_full = len(patch_centers) == self.batch_size - if is_node_far or is_batch_full: + if self.graph.dist(root, j) > self.max_batch_span: # Yield batch metadata patch_centers = np.array(patch_centers, dtype=int) nodes = np.array(nodes, dtype=int) @@ -678,9 +556,6 @@ def _generate_batch_nodes(self, root): root = j # --- Helpers --- - def __getattr__(self, name): - return getattr(self.graph, name) - def estimate_iterations(self): """ Estimates the number of iterations required to search graph. @@ -691,3 +566,30 @@ def estimate_iterations(self): Estimated number of iterations required to search graph. """ return len(self.graph.get_branchings()) + + +class BranchingDataset(GraphDataset): + + def __init__( + self, + graph, + img_config, + is_multimodal=False, + min_search_size=0, + prefetch=64, + step_size=10, + subgraph_radius=100, + ): + # Call parent class + super().__init__( + graph, + img_config, + is_multimodal=is_multimodal, + min_search_size=min_search_size, + prefetch=prefetch, + subgraph_radius=subgraph_radius, + ) + + # Instance attributes + self.patch_loader = DetectionPatchLoader(self.graph, img_config) + self.search_mode = "branching_nodes" diff --git a/src/neuron_proofreader/skeleton_graph.py b/src/neuron_proofreader/skeleton_graph.py index c53beda..e26a3f7 100644 --- a/src/neuron_proofreader/skeleton_graph.py +++ b/src/neuron_proofreader/skeleton_graph.py @@ -531,11 +531,10 @@ def component_to_zipped_swc(self, zipfile, root, use_radius=False): ZipFile object that will store the generated SWC file. root : int Root node of connected component to be written to an SWC file. - ususe_radius : bool, optional + use_radius : bool, optional Indication of whether to preserve radii of nodes or use default radius of 2μm. Default is False. """ - # Subroutines def write_entry(node, parent): """ diff --git a/src/neuron_proofreader/utils/swc_util.py b/src/neuron_proofreader/utils/swc_util.py index 3be81da..920af30 100644 --- a/src/neuron_proofreader/utils/swc_util.py +++ b/src/neuron_proofreader/utils/swc_util.py @@ -109,7 +109,7 @@ def __call__(self, swc_pointer): return self.read_zips(swc_pointer, self.read_zip) # Local SWC files - paths = util.read_paths(swc_pointer, extension=".swc") + paths = util.list_paths(swc_pointer, extension=".swc") if len(paths) > 0: return self.read_swcs(paths) @@ -202,7 +202,7 @@ def read_zips(self, zip_paths, read_fn): pbar = self.manual_progress_bar(len(zip_paths)) with ProcessPoolExecutor() as executor: # Assign processes - futures = {executor.submit(read_fn, path) for path in zip_paths} + futures = {executor.submit(read_fn, path) for path in zip_paths[0:500]} # TEMP # Store results swc_dicts = deque()