From ed64aa022ca6c2be01d8b8eb30ee2c15fe2aedf6 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Mon, 15 Jun 2026 17:02:56 +0000 Subject: [PATCH 1/5] refactor: patch loader in merge inference --- src/neuron_proofreader/configs.py | 1 + .../machine_learning/image_dataloader.py | 86 ++++--- .../merge_proofreading/merge_inference.py | 221 ++++++------------ src/neuron_proofreader/skeleton_graph.py | 3 +- src/neuron_proofreader/utils/swc_util.py | 4 +- 5 files changed, 129 insertions(+), 186 deletions(-) diff --git a/src/neuron_proofreader/configs.py b/src/neuron_proofreader/configs.py index 90be82ef..9a8379aa 100644 --- a/src/neuron_proofreader/configs.py +++ b/src/neuron_proofreader/configs.py @@ -97,6 +97,7 @@ class ImageConfig(Config): """ brightness_clip: int = 400 + img_path: str = None name: str = "image_config" percentiles: Tuple[float, float] = (1, 99.5) patch_shape: Tuple[int, int, int] = (128, 128, 128) diff --git a/src/neuron_proofreader/machine_learning/image_dataloader.py b/src/neuron_proofreader/machine_learning/image_dataloader.py index a0c1834d..36120841 100644 --- a/src/neuron_proofreader/machine_learning/image_dataloader.py +++ b/src/neuron_proofreader/machine_learning/image_dataloader.py @@ -20,7 +20,10 @@ from neuron_proofreader.utils import geometry_util, img_util, util -# --- Image Reading --- +# ---------------------------------------------------------------------------- +# Image Class +# ---------------------------------------------------------------------------- + class TensorStoreImage: """ Class that reads images with the TensorStore library. @@ -93,7 +96,10 @@ def shape(self): return self.img.shape -# --- Patch Loading --- +# ---------------------------------------------------------------------------- +# PatchLoader Class +# ---------------------------------------------------------------------------- + class PatchLoader(ABC): """ A class for reading image patches and generating segment masks. @@ -101,7 +107,7 @@ class PatchLoader(ABC): max_voxel_shift = 5 - def __init__(self, graph, img_config, img_path): + def __init__(self, graph, img_config): """ Instantiates a PatchLoader object. @@ -116,7 +122,7 @@ def __init__(self, graph, img_config, img_path): """ self.config = img_config or ImageConfig() self.graph = graph - self.img = TensorStoreImage(img_path) + self.img = TensorStoreImage(img_config.img_path) self.transform = ImageTransforms() if self.config.transform else None # --- Abstract Interface --- @@ -134,18 +140,16 @@ def compute_patch_specs(self): """ pass - @abstractmethod - def create_mask(self): - """ - Abstract method to be implemented by subclasses. - """ - pass - # --- Core Routines --- def annotate_foreground(self, mask, nodes, offset, fill_val=1): visited = set() for i in nodes: + # Check whether to visit voxel voxel_i = self.graph.node_local_voxel(i, offset) + if not img_util.is_contained(voxel_i, mask.shape): + continue + + # Visit neighbors for j in self.graph.neighbors(i): if frozenset({i, j}) not in visited and j in nodes: voxel_j = self.graph.node_local_voxel(j, offset) @@ -163,6 +167,19 @@ def annotate_fragment(self, mask, subgraph, offset, fill_val=1): voxels = geometry_util.make_digital_line(voxel1, voxel2) img_util.annotate_voxels(mask, voxels, fill_val=fill_val) + def create_mask(self, center, shape, node): + # Initializations + offset = img_util.get_offset(center, shape) + depth = np.sqrt(2) * np.max(shape) / (2 * self.graph.anisotropy.min()) + nodes = self.get_foreground_nodes(node, depth) + subgraph = self.graph.rooted_subgraph(node, depth) + + # Annotate mask + mask = np.zeros(shape) + #self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP + self.annotate_fragment(mask, subgraph, offset, fill_val=1) + return mask + def read_image(self, center, shape): """ Reads the image patch specified by the given center and shape. @@ -195,6 +212,11 @@ def adjust_voxel(self, voxel): ) return voxel + def get_foreground_nodes(self, node, radius): + xyz = self.graph.node_xyz[node] + nodes = self.graph.kdtree.query_ball_point(xyz, radius) + return nodes + @staticmethod def stack(img, mask): try: @@ -205,6 +227,10 @@ def stack(img, mask): return patches +# ---------------------------------------------------------------------------- +# PatchLoader Subclasses +# ---------------------------------------------------------------------------- + class DetectionPatchLoader(PatchLoader): # --- Implementation of Abstract Inferface --- @@ -225,24 +251,30 @@ def compute_patch_specs(self, node): voxel = self.adjust_voxel(voxel) return voxel, self.patch_shape - def create_mask(self, center, shape, node): - # Initializations - offset = img_util.get_offset(center, shape) - depth = np.sqrt(2) * np.max(shape) / (2 * self.graph.anisotropy.min()) - nodes = self.get_foreground_nodes(node, depth) - subgraph = self.graph.rooted_subgraph(node, depth) - # Annotate mask - mask = np.zeros(shape) - #self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP - self.annotate_fragment(mask, subgraph, offset, fill_val=1) - return mask +class DetectionBatchLoader(PatchLoader): - # --- Helpers --- - def get_foreground_nodes(self, node, radius): - xyz = self.graph.node_xyz[node] - nodes = self.graph.kdtree.query_ball_point(xyz, radius) - return nodes + def __call__(self, nodes): + # Compute patch info + center, shape = self.compute_patch_specs(nodes) + offset = center - shape // 2 + + # Load patches + img = self.read_image(center, shape) + mask = self.create_mask(center, shape, nodes[len(nodes) // 2]) + return self.stack(img, mask), offset + + def compute_patch_specs(self, nodes): + # Compute bounding box + centers = np.array([self.graph.node_voxel(i) for i in nodes]) + buffer = np.array(self.patch_shape) // 2 + start = centers.min(axis=0) - buffer + end = centers.max(axis=0) + buffer + + # Image patch location + shape = (end - start).astype(int) + center = (start + shape // 2).astype(int) + return center, shape class ProposalPatchLoader(PatchLoader): diff --git a/src/neuron_proofreader/merge_proofreading/merge_inference.py b/src/neuron_proofreader/merge_proofreading/merge_inference.py index ac1b8ce5..79e6e6cb 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_inference.py +++ b/src/neuron_proofreader/merge_proofreading/merge_inference.py @@ -25,8 +25,11 @@ from neuron_proofreader.machine_learning.point_cloud_models import ( subgraph_to_point_cloud, ) +from neuron_proofreader.machine_learning.image_dataloader import ( + DetectionBatchLoader, + DetectionPatchLoader, +) from neuron_proofreader.utils import ( - geometry_util, img_util, ml_util, swc_util, @@ -48,7 +51,7 @@ def __init__( # Instance attributes self.dataset = dataset self.device = device - self.node_preds = np.ones((len(dataset.node_xyz))) * 1e-2 + self.node_preds = np.zeros((len(dataset.node_xyz))) self.patch_shape = dataset.patch_shape self.remove_detected_sites = remove_detected_sites self.threshold = threshold @@ -132,7 +135,7 @@ def filter_with_nms(self, merge_sites, likelihoods): ) if iou > 0.35: merge_sites_set.remove(i) - self.node_preds[i] = 1e-2 + self.node_preds[i] = 0 # Populate queue for j in self.dataset.neighbors(i): @@ -185,8 +188,9 @@ def save_results( ): self.save_sites(output_dir) if save_fragments: + self.dataset.graph.node_radius = 10 * np.maximum(self.node_preds, 0.1) fragments_path = os.path.join(output_dir, "fragments.zip") - self.dataset.to_zipped_swcs(fragments_path) + self.dataset.to_zipped_swcs(fragments_path, use_radius=True) # Upload results to S3 (if applicable) if output_prefix_s3: @@ -195,12 +199,15 @@ def save_results( def save_sites(self, output_dir): # Save model predictions - df = pd.DataFrame(columns=["World", "Segment_ID", "Prediction"]) - df["World"] = self.dataset.node_xyz + df = pd.DataFrame( + columns=["World", "Segment_ID", "Prediction", "Degree"] + ) + df["World"] = list(map(tuple, self.dataset.node_xyz)) df["Prediction"] = self.node_preds df["Segment_ID"] = [ self.dataset.node_segment_id(i) for i in self.dataset.nodes ] + df["Degree"] = [self.dataset.degree[i] for i in self.dataset.nodes] df.to_csv(os.path.join(output_dir, "model_predictions.csv")) # Get predicted merge sites @@ -235,38 +242,31 @@ def save_train_dataset(self, output_dir): print("# Fragments Saved:", len(roots)) -# --- Data Handling --- +# --- Datasets --- class GraphDataset(IterableDataset, ABC): def __init__( self, graph, - img_path, - patch_shape, + img_config, batch_size=16, - brightness_clip=400, is_multimodal=False, min_search_size=0, prefetch=64, - segmentation_path=None, subgraph_radius=100, - use_new_mask=False, ): # Call parent class super().__init__() # Instance attributes self.batch_size = batch_size - self.brightness_clip = brightness_clip self.distance_traversed = 0 self.graph = graph self.is_multimodal = is_multimodal self.min_size = min_search_size - self.patch_shape = patch_shape + self.patch_shape = img_config.patch_shape self.prefetch = prefetch - self.segmentation_path = segmentation_path self.subgraph_radius = subgraph_radius - self.use_new_mask = use_new_mask # Batch getter if is_multimodal: @@ -274,13 +274,6 @@ def __init__( else: self.get_batch = self._get_batch - # Image reader - self.img_reader = img_util.TensorStoreReader(img_path) - if self.segmentation_path: - self.segmentation_reader = img_util.TensorStoreReader( - segmentation_path - ) - # --- Core routines --- def __iter__(self): # Find fragment IDs to check @@ -326,70 +319,12 @@ def find_fragments_to_search(self): component_ids.add(self.node_component_id[node]) return component_ids - def get_patch_centers(self, nodes): - patch_centers = [self.node_voxel(i) for i in nodes] - return np.array(patch_centers, dtype=int) - - def get_label_mask(self, nodes, img_shape, offset): - # Read segmentation - if self.use_new_mask: - center = [o + s // 2 for o, s in zip(offset, img_shape)] - segment_mask = self.segmentation_reader.read(center, img_shape) - segment_mask = img_util.remove_small_segments(segment_mask, 1000) - segment_mask = 0.5 * (segment_mask > 0).astype(int) - else: - segment_mask = np.zeros(img_shape) - - # Annotate mask - subgraph = self.get_contained_subgraph(nodes, img_shape, offset) - for i, j in subgraph.edges: - voxel_i = self.node_voxel(i) - offset - voxel_j = self.node_voxel(j) - offset - voxels = geometry_util.make_digital_line(voxel_i, voxel_j) - img_util.annotate_voxels(segment_mask, voxels) - return segment_mask - - def get_contained_subgraph(self, nodes, img_shape, offset): - queue = list(nodes) - visited = set(nodes) - subgraph = nx.Graph() - while queue: - # Visit node - i = queue.pop() - voxel_i = self.node_voxel(i) - offset - if not img_util.is_contained(voxel_i, img_shape, buffer=1): - continue - - # Update queue - for j in self.neighbors(i): - voxel_j = self.node_voxel(j) - offset - if img_util.is_contained(voxel_j, img_shape): - subgraph.add_edge(i, j) - if j not in visited: - queue.append(j) - visited.add(j) - return subgraph - def is_contained(self, node): voxel = self.node_voxel(node) - shape = self.img_reader.shape()[2::] + shape = self.patch_loader.img.shape()[2::] buffer = np.max(self.patch_shape) + 1 return img_util.is_contained(voxel, shape, buffer=buffer) - def read_superchunk(self, nodes): - # Compute bounding box - patch_centers = self.get_patch_centers(nodes) - buffer = 1 + np.array(self.patch_shape) // 2 - start = patch_centers.min(axis=0) - buffer - end = patch_centers.max(axis=0) + buffer - - # Read image - shape = (end - start).astype(int) - center = (start + shape // 2).astype(int) - superchunk = self.img_reader.read(center, shape) - superchunk = np.minimum(superchunk, self.brightness_clip) - return superchunk, start.astype(int) - def is_near_leaf(self, node, threshold=20): # Check if node is branching if self.degree[node] > 2: @@ -417,40 +352,39 @@ def is_node_valid(self, node): is_nonleaf = not self.is_near_leaf(node) return is_contained and is_nonleaf + # --- Helpers --- + def __getattr__(self, name): + return getattr(self.graph, name) + + +class DenseDataset(GraphDataset): -class DenseGraphDataset(GraphDataset): + max_batch_span = 512 def __init__( self, graph, - img_path, - patch_shape, + img_config, batch_size=16, - brightness_clip=300, is_multimodal=False, min_search_size=0, - prefetch=128, - segmentation_path=None, + prefetch=64, step_size=10, subgraph_radius=100, - use_new_mask=False, ): # Call parent class super().__init__( graph, - img_path, - patch_shape, + img_config, batch_size=batch_size, - brightness_clip=brightness_clip, is_multimodal=is_multimodal, min_search_size=min_search_size, prefetch=prefetch, - segmentation_path=segmentation_path, subgraph_radius=subgraph_radius, - use_new_mask=use_new_mask, ) # Instance attributes + self.patch_loader = DetectionBatchLoader(self.graph, img_config) self.search_mode = "dense" self.step_size = step_size @@ -459,7 +393,7 @@ def _generate_batches_from_component(self, root): def submit_thread(): try: nodes = next(batch_nodes_generator) - thread = executor.submit(self.read_superchunk, nodes) + thread = executor.submit(self.patch_loader, nodes) pending[thread] = nodes except StopIteration: pass @@ -511,7 +445,7 @@ def _generate_batch_nodes(self, root): continue # Check whether to yield batch - is_node_far = self.dist(root, j) > 512 + is_node_far = self.dist(root, j) > self.max_batch_span is_batch_full = len(nodes) == self.batch_size if is_node_far or is_batch_full: # Yield nodes in batch @@ -534,22 +468,11 @@ def _generate_batch_nodes(self, root): yield np.array(nodes, dtype=int) def _get_batch(self, nodes, img, offset): - # Initializations - label_mask = self.get_label_mask(nodes, img.shape, offset) - patch_centers = self.get_patch_centers(nodes) - offset - - # Populate batch array - batch = np.empty( - ( - len(nodes), - 2, - ) - + self.patch_shape - ) - for i, center in enumerate(patch_centers): + batch = np.empty((len(nodes), 2,) + self.patch_shape) + voxels = np.array([self.node_voxel(i) for i in nodes], dtype=int) + for i, center in enumerate(voxels - offset): s = img_util.get_slices(center, self.patch_shape) - batch[i, 0, ...] = img_util.normalize(img[s]) - batch[i, 1, ...] = label_mask[s] + batch[i] = img[(slice(0, 2), *s)] return nodes, torch.tensor(batch, dtype=torch.float) def _get_multimodal_batch(self, nodes, img, offset): @@ -558,13 +481,7 @@ def _get_multimodal_batch(self, nodes, img, offset): patch_centers = self.get_patch_centers(nodes) - offset # Populate batch array - patches = np.empty( - ( - len(nodes), - 2, - ) - + self.patch_shape - ) + patches = np.empty((len(nodes), 2,) + self.patch_shape) point_clouds = np.empty((len(nodes), 3, 3600), dtype=np.float32) for i, (node, center) in enumerate(zip(nodes, patch_centers)): s = img_util.get_slices(center, self.patch_shape) @@ -584,9 +501,6 @@ def _get_multimodal_batch(self, nodes, img, offset): return nodes, batch # --- Helpers --- - def __getattr__(self, name): - return getattr(self.graph, name) - def estimate_iterations(self): """ Estimates the number of iterations required to search graph. @@ -611,37 +525,7 @@ def estimate_iterations(self): return int(total_cable_length / self.step_size) -class SparseGraphDataset(GraphDataset): - - def __init__( - self, - graph, - img_path, - patch_shape, - batch_size=16, - is_multimodal=False, - min_search_size=0, - prefetch=128, - segmentation_path=None, - subgraph_radius=100, - use_new_mask=False, - ): - # Call parent class - super().__init__( - graph, - img_path, - patch_shape, - batch_size=batch_size, - is_multimodal=is_multimodal, - min_search_size=min_search_size, - prefetch=prefetch, - segmentation_path=segmentation_path, - subgraph_radius=subgraph_radius, - use_new_mask=use_new_mask, - ) - - # Instance attributes - self.search_mode = "branching_points" +class SparseDataset(GraphDataset): def _generate_batches_from_component(self): pass @@ -678,9 +562,6 @@ def _generate_batch_nodes(self, root): root = j # --- Helpers --- - def __getattr__(self, name): - return getattr(self.graph, name) - def estimate_iterations(self): """ Estimates the number of iterations required to search graph. @@ -691,3 +572,33 @@ def estimate_iterations(self): Estimated number of iterations required to search graph. """ return len(self.graph.get_branchings()) + + +class BranchingDataset(GraphDataset): + + def __init__( + self, + graph, + img_config, + batch_size=16, + is_multimodal=False, + min_search_size=0, + prefetch=128, + step_size=10, + subgraph_radius=100, + ): + # Call parent class + super().__init__( + graph, + img_config, + batch_size=batch_size, + is_multimodal=is_multimodal, + min_search_size=min_search_size, + prefetch=prefetch, + subgraph_radius=subgraph_radius, + ) + + # Instance attributes + self.patch_loader = DetectionPatchLoader(self.graph, img_config) + self.search_mode = "branching_nodes" + diff --git a/src/neuron_proofreader/skeleton_graph.py b/src/neuron_proofreader/skeleton_graph.py index ca47684f..7904f859 100644 --- a/src/neuron_proofreader/skeleton_graph.py +++ b/src/neuron_proofreader/skeleton_graph.py @@ -539,11 +539,10 @@ def component_to_zipped_swc(self, zipfile, root, use_radius=False): ZipFile object that will store the generated SWC file. root : int Root node of connected component to be written to an SWC file. - ususe_radius : bool, optional + use_radius : bool, optional Indication of whether to preserve radii of nodes or use default radius of 2μm. Default is False. """ - # Subroutines def write_entry(node, parent): """ diff --git a/src/neuron_proofreader/utils/swc_util.py b/src/neuron_proofreader/utils/swc_util.py index 3be81da2..920af307 100644 --- a/src/neuron_proofreader/utils/swc_util.py +++ b/src/neuron_proofreader/utils/swc_util.py @@ -109,7 +109,7 @@ def __call__(self, swc_pointer): return self.read_zips(swc_pointer, self.read_zip) # Local SWC files - paths = util.read_paths(swc_pointer, extension=".swc") + paths = util.list_paths(swc_pointer, extension=".swc") if len(paths) > 0: return self.read_swcs(paths) @@ -202,7 +202,7 @@ def read_zips(self, zip_paths, read_fn): pbar = self.manual_progress_bar(len(zip_paths)) with ProcessPoolExecutor() as executor: # Assign processes - futures = {executor.submit(read_fn, path) for path in zip_paths} + futures = {executor.submit(read_fn, path) for path in zip_paths[0:500]} # TEMP # Store results swc_dicts = deque() From c5ee363cd19f6c92154d6d9efd0513e8e7cd72f8 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Tue, 16 Jun 2026 02:06:30 +0000 Subject: [PATCH 2/5] refactor: simplified batch loading --- .../merge_proofreading/merge_inference.py | 43 ++++++++----------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/src/neuron_proofreader/merge_proofreading/merge_inference.py b/src/neuron_proofreader/merge_proofreading/merge_inference.py index 79e6e6cb..72bd37c9 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_inference.py +++ b/src/neuron_proofreader/merge_proofreading/merge_inference.py @@ -12,7 +12,7 @@ from abc import ABC, abstractmethod from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait from torch.nn.functional import sigmoid -from torch.utils.data import IterableDataset +from torch.utils.data import IterableDataset, DataLoader from time import time from tqdm import tqdm @@ -44,11 +44,13 @@ def __init__( dataset, model, model_path, + batch_size=16, device="cuda", remove_detected_sites=False, threshold=0.4, ): # Instance attributes + self.batch_size = batch_size self.dataset = dataset self.device = device self.node_preds = np.zeros((len(dataset.node_xyz))) @@ -64,8 +66,13 @@ def __init__( def search_graph(self): # Iterate over dataset t0 = time() + dataloader = DataLoader( + self.dataset, + batch_size=self.batch_size, + num_workers=0, + ) pbar = tqdm(total=self.dataset.estimate_iterations()) - for nodes, x_nodes in self.dataset: + for nodes, x_nodes in dataloader: #self.dataset: self.node_preds[np.array(nodes)] = self.predict(x_nodes) pbar.update(len(nodes)) @@ -249,7 +256,6 @@ def __init__( self, graph, img_config, - batch_size=16, is_multimodal=False, min_search_size=0, prefetch=64, @@ -259,7 +265,6 @@ def __init__( super().__init__() # Instance attributes - self.batch_size = batch_size self.distance_traversed = 0 self.graph = graph self.is_multimodal = is_multimodal @@ -270,9 +275,9 @@ def __init__( # Batch getter if is_multimodal: - self.get_batch = self._get_multimodal_batch + self.generate_inputs = None else: - self.get_batch = self._get_batch + self.generate_inputs = self.generate_patches # --- Core routines --- def __iter__(self): @@ -365,7 +370,6 @@ def __init__( self, graph, img_config, - batch_size=16, is_multimodal=False, min_search_size=0, prefetch=64, @@ -376,7 +380,6 @@ def __init__( super().__init__( graph, img_config, - batch_size=batch_size, is_multimodal=is_multimodal, min_search_size=min_search_size, prefetch=prefetch, @@ -414,7 +417,7 @@ def submit_thread(): # Process completed thread nodes = pending.pop(thread) img, offset = thread.result() - yield self.get_batch(nodes, img, offset) + yield from self.generate_inputs(nodes, img, offset) # Continue submitting threads submit_thread() @@ -445,9 +448,7 @@ def _generate_batch_nodes(self, root): continue # Check whether to yield batch - is_node_far = self.dist(root, j) > self.max_batch_span - is_batch_full = len(nodes) == self.batch_size - if is_node_far or is_batch_full: + if self.dist(root, j) > self.max_batch_span: # Yield nodes in batch yield np.array(nodes, dtype=int) @@ -467,13 +468,12 @@ def _generate_batch_nodes(self, root): if nodes: yield np.array(nodes, dtype=int) - def _get_batch(self, nodes, img, offset): - batch = np.empty((len(nodes), 2,) + self.patch_shape) + def generate_patches(self, nodes, img, offset): voxels = np.array([self.node_voxel(i) for i in nodes], dtype=int) - for i, center in enumerate(voxels - offset): + for node, center in zip(nodes, voxels - offset): s = img_util.get_slices(center, self.patch_shape) - batch[i] = img[(slice(0, 2), *s)] - return nodes, torch.tensor(batch, dtype=torch.float) + patch = torch.from_numpy(img[(slice(0, 2), *s)]).float() + yield node, patch def _get_multimodal_batch(self, nodes, img, offset): # Initializations @@ -542,9 +542,7 @@ def _generate_batch_nodes(self, root): patch_centers.append(self.graph.node_voxel(i)) # Check whether to yield batch - is_node_far = self.graph.dist(root, j) > 256 - is_batch_full = len(patch_centers) == self.batch_size - if is_node_far or is_batch_full: + if self.graph.dist(root, j) > 256: # Yield batch metadata patch_centers = np.array(patch_centers, dtype=int) nodes = np.array(nodes, dtype=int) @@ -580,10 +578,9 @@ def __init__( self, graph, img_config, - batch_size=16, is_multimodal=False, min_search_size=0, - prefetch=128, + prefetch=64, step_size=10, subgraph_radius=100, ): @@ -591,7 +588,6 @@ def __init__( super().__init__( graph, img_config, - batch_size=batch_size, is_multimodal=is_multimodal, min_search_size=min_search_size, prefetch=prefetch, @@ -601,4 +597,3 @@ def __init__( # Instance attributes self.patch_loader = DetectionPatchLoader(self.graph, img_config) self.search_mode = "branching_nodes" - From 3012a473013a02821fe70b05fa8a394590ffca62 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Tue, 16 Jun 2026 02:11:15 +0000 Subject: [PATCH 3/5] removed comment --- .../merge_proofreading/merge_inference.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/neuron_proofreader/merge_proofreading/merge_inference.py b/src/neuron_proofreader/merge_proofreading/merge_inference.py index 72bd37c9..9463827b 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_inference.py +++ b/src/neuron_proofreader/merge_proofreading/merge_inference.py @@ -66,13 +66,9 @@ def __init__( def search_graph(self): # Iterate over dataset t0 = time() - dataloader = DataLoader( - self.dataset, - batch_size=self.batch_size, - num_workers=0, - ) + dataloader = DataLoader(self.dataset, batch_size=self.batch_size) pbar = tqdm(total=self.dataset.estimate_iterations()) - for nodes, x_nodes in dataloader: #self.dataset: + for nodes, x_nodes in dataloader: self.node_preds[np.array(nodes)] = self.predict(x_nodes) pbar.update(len(nodes)) From 026e2e110f5d8334298cfddfa68a1832f7ff5193 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Tue, 16 Jun 2026 03:42:36 +0000 Subject: [PATCH 4/5] refactor: detection speedup --- .../merge_proofreading/merge_detection.py | 230 +++++++ .../merge_proofreading/merge_inference.py | 595 ------------------ .../merge_proofreading/search_datasets.py | 319 ++++++++++ 3 files changed, 549 insertions(+), 595 deletions(-) create mode 100644 src/neuron_proofreader/merge_proofreading/merge_detection.py delete mode 100644 src/neuron_proofreader/merge_proofreading/merge_inference.py create mode 100644 src/neuron_proofreader/merge_proofreading/search_datasets.py diff --git a/src/neuron_proofreader/merge_proofreading/merge_detection.py b/src/neuron_proofreader/merge_proofreading/merge_detection.py new file mode 100644 index 00000000..280009be --- /dev/null +++ b/src/neuron_proofreader/merge_proofreading/merge_detection.py @@ -0,0 +1,230 @@ +""" +Created on Wed June 15 16:00:00 2026 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +Code for detecting merge mistakes on skeletons generated from an automated +image segmentation. + +""" + +from torch.nn.functional import sigmoid +from torch.utils.data import DataLoader +from time import time +from tqdm import tqdm + +import numpy as np +import os +import pandas as pd +import torch + +from neuron_proofreader.utils import img_util, ml_util, swc_util, util + + +class MergeDetector: + + def __init__( + self, + dataset, + model, + model_path, + batch_size=16, + device="cuda", + remove_detected_sites=False, + threshold=0.4, + ): + # Instance attributes + self.batch_size = batch_size + self.dataset = dataset + self.device = device + self.node_preds = np.zeros((len(dataset.node_xyz))) + self.patch_shape = dataset.patch_shape + self.remove_detected_sites = remove_detected_sites + self.threshold = threshold + + # Load model + self.model = model + ml_util.load_model(model, model_path, device=self.device) + + # --- Core routines --- + def search_graph(self): + # Iterate over dataset + t0 = time() + dataloader = DataLoader(self.dataset, batch_size=self.batch_size) + pbar = tqdm(total=self.dataset.estimate_iterations()) + for nodes, x_nodes in dataloader: + self.node_preds[np.array(nodes)] = self.predict(x_nodes) + pbar.update(len(nodes)) + + # Non-maximum suppression of detected sites + merge_sites = np.where(self.node_preds > self.threshold)[0] + likelihoods = self.node_preds[merge_sites] + merge_sites = self.filter_with_nms(merge_sites, likelihoods) + + # Report results + rate = self.dataset.distance_traversed / (time() - t0) + print("\n# Detected Merge Sites:", len(merge_sites)) + print(f"Distance Traversed: {self.dataset.distance_traversed:.2f}μm") + print(f"Merge Proofreading Rate: {rate:.2f}μm/s") + + # Remove merge mistakes (optional) + if self.remove_detected_sites: + pass + return merge_sites + + def predict(self, x): + """ + Predicts merge site likelihoods for the given node features. + + Parameters + ---------- + x : torch.Tensor + Node features with shape Nx2xMxMxM, where N is the number of nodes + and MxMxM is the patch shape. + + Returns + ------- + numpy.ndarray + Predicted merge site likelihoods. + """ + with torch.inference_mode(): + x = x.to(self.device) + y = sigmoid(self.model(x)) + return np.squeeze(ml_util.to_cpu(y, to_numpy=True), axis=1) + + def filter_with_nms(self, merge_sites, likelihoods): + # Sort by confidence + merge_sites = [merge_sites[i] for i in np.argsort(likelihoods)] + + # NMS + merge_sites_set = set(merge_sites) + filtered_merge_sites = set() + while merge_sites: + # Local max + root = merge_sites.pop() + xyz_root = self.dataset.node_xyz[root] + if root in merge_sites_set: + filtered_merge_sites.add(root) + merge_sites_set.remove(root) + else: + continue + + # Suppress neighborhood + queue = [(root, 0)] + visited = set([root]) + while queue: + # Visit node + i, dist_i = queue.pop() + if i in merge_sites_set: + xyz_i = self.dataset.node_xyz[i] + iou = img_util.compute_iou3d( + xyz_i, xyz_root, self.patch_shape, self.patch_shape + ) + if iou > 0.35: + merge_sites_set.remove(i) + self.node_preds[i] = 0 + + # Populate queue + for j in self.dataset.neighbors(i): + dist_j = dist_i + self.dataset.dist(i, j) + if j not in visited and dist_j < self.patch_shape[0]: + queue.append((j, dist_j)) + visited.add(j) + return filtered_merge_sites + + def remove_merge_sites(self, merge_site_nodes, max_depth=10): + rm_nodes = set() + for root in tqdm(merge_site_nodes, desc="Remove Merge Sites"): + # Extract neighborhood + root = self.dataset.find_nearby_branching_node(root) + nbhd = self.dataset.nodes_within_distance(root, max_depth) + + # Check for branching node in neighborhood + for i in list(nbhd): + if i != root and self.dataset.degree[i] >= 3: + nbhd_i = self.dataset.nodes_within_distance(root, 8) + nbhd.extend(nbhd_i) + + # Add nodes to removal list + rm_nodes.update(set(nbhd)) + + # Update graph + self.dataset.remove_nodes(rm_nodes) + print("# Nodes Deleted:", len(rm_nodes)) + + # --- Helpers --- + def get_detected_sites(self, threshold): + nodes = np.where(self.node_preds >= threshold)[0] + return [self.dataset.node_xyz[i] for i in nodes] + + def save_parameters(self, output_dir): + json_path = os.path.join(output_dir, "detection_parameters.json") + parameters = { + "accept_threshold": self.threshold, + "is_multimodal": self.dataset.is_multimodal, + "min_search_size": self.dataset.min_size, + "patch_shape": self.patch_shape, + "remove_detected_sites": self.remove_detected_sites, + "search_mode": self.dataset.search_mode, + "subgraph_radius": self.dataset.subgraph_radius, + } + util.write_json(json_path, parameters) + + def save_results( + self, output_dir, output_prefix_s3=None, save_fragments=True + ): + self.save_sites(output_dir) + if save_fragments: + self.dataset.graph.node_radius = 10 * np.maximum(self.node_preds, 0.1) + fragments_path = os.path.join(output_dir, "fragments.zip") + self.dataset.to_zipped_swcs(fragments_path, use_radius=True) + + # Upload results to S3 (if applicable) + if output_prefix_s3: + bucket_name, prefix = util.parse_cloud_path(output_prefix_s3) + util.upload_dir_to_s3(output_dir, bucket_name, prefix) + + def save_sites(self, output_dir): + # Save model predictions + df = pd.DataFrame( + columns=["World", "Segment_ID", "Prediction", "Degree"] + ) + df["World"] = list(map(tuple, self.dataset.node_xyz)) + df["Prediction"] = self.node_preds + df["Segment_ID"] = [ + self.dataset.node_segment_id(i) for i in self.dataset.nodes + ] + df["Degree"] = [self.dataset.degree[i] for i in self.dataset.nodes] + df.to_csv(os.path.join(output_dir, "model_predictions.csv")) + + # Get predicted merge sites + nodes = np.where(self.node_preds >= self.threshold)[0] + detected_sites = [self.dataset.node_xyz[i] for i in nodes] + print("# Sites Saved:", len(nodes)) + + # Save predicted merge sites + zip_path = os.path.join(output_dir, "detected_sites.zip") + swc_util.write_points( + zip_path, + detected_sites, + color="1.0 0.0 0.0", + prefix="merge-site", + radius=10, + ) + + def save_train_dataset(self, output_dir): + # Extract fragments to save + roots = list() + visited_ids = set() + for i in np.where(self.node_preds >= self.threshold)[0]: + cc_id = self.dataset.node_component_id[i] + if cc_id not in visited_ids: + roots.append([i]) + visited_ids.add(cc_id) + + # Save fragments + zip_path = os.path.join(output_dir, "fragments.zip") + self.dataset._batch_to_zipped_swcs(roots, zip_path, False) + self.save_sites(output_dir) + print("# Fragments Saved:", len(roots)) diff --git a/src/neuron_proofreader/merge_proofreading/merge_inference.py b/src/neuron_proofreader/merge_proofreading/merge_inference.py deleted file mode 100644 index a821d990..00000000 --- a/src/neuron_proofreader/merge_proofreading/merge_inference.py +++ /dev/null @@ -1,595 +0,0 @@ -""" -Created on Wed August 4 16:00:00 2025 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -Code for detecting merge mistakes on skeletons generated from an automated -image segmentation. - -""" - -from abc import ABC, abstractmethod -from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait -from torch.nn.functional import sigmoid -from torch.utils.data import IterableDataset, DataLoader -from time import time -from tqdm import tqdm - -import networkx as nx -import numpy as np -import os -import pandas as pd -import torch - -from neuron_proofreader.machine_learning.point_cloud_models import ( - subgraph_to_point_cloud, -) -from neuron_proofreader.machine_learning.image_dataloader import ( - DetectionBatchLoader, - DetectionPatchLoader, -) -from neuron_proofreader.utils import ( - img_util, - ml_util, - swc_util, - util, -) - - -class MergeDetector: - - def __init__( - self, - dataset, - model, - model_path, - batch_size=16, - device="cuda", - remove_detected_sites=False, - threshold=0.4, - ): - # Instance attributes - self.batch_size = batch_size - self.dataset = dataset - self.device = device - self.node_preds = np.zeros((len(dataset.node_xyz))) - self.patch_shape = dataset.patch_shape - self.remove_detected_sites = remove_detected_sites - self.threshold = threshold - - # Load model - self.model = model - ml_util.load_model(model, model_path, device=self.device) - - # --- Core routines --- - def search_graph(self): - # Iterate over dataset - t0 = time() - dataloader = DataLoader(self.dataset, batch_size=self.batch_size) - pbar = tqdm(total=self.dataset.estimate_iterations()) - for nodes, x_nodes in dataloader: - self.node_preds[np.array(nodes)] = self.predict(x_nodes) - pbar.update(len(nodes)) - - # Non-maximum suppression of detected sites - merge_sites = np.where(self.node_preds > self.threshold)[0] - likelihoods = self.node_preds[merge_sites] - merge_sites = self.filter_with_nms(merge_sites, likelihoods) - - # Report results - rate = self.dataset.distance_traversed / (time() - t0) - print("\n# Detected Merge Sites:", len(merge_sites)) - print(f"Distance Traversed: {self.dataset.distance_traversed:.2f}μm") - print(f"Merge Proofreading Rate: {rate:.2f}μm/s") - - # Remove merge mistakes (optional) - if self.remove_detected_sites: - pass - return merge_sites - - def predict(self, x): - """ - Predicts merge site likelihoods for the given node features. - - Parameters - ---------- - x : torch.Tensor - Node features with shape Nx2xMxMxM, where N is the number of nodes - and MxMxM is the patch shape. - - Returns - ------- - numpy.ndarray - Predicted merge site likelihoods. - """ - with torch.inference_mode(): - x = x.to(self.device) - y = sigmoid(self.model(x)) - return np.squeeze(ml_util.to_cpu(y, to_numpy=True), axis=1) - - def filter_with_nms(self, merge_sites, likelihoods): - # Sort by confidence - merge_sites = [merge_sites[i] for i in np.argsort(likelihoods)] - - # NMS - merge_sites_set = set(merge_sites) - filtered_merge_sites = set() - while merge_sites: - # Local max - root = merge_sites.pop() - xyz_root = self.dataset.node_xyz[root] - if root in merge_sites_set: - filtered_merge_sites.add(root) - merge_sites_set.remove(root) - else: - continue - - # Suppress neighborhood - queue = [(root, 0)] - visited = set([root]) - while queue: - # Visit node - i, dist_i = queue.pop() - if i in merge_sites_set: - xyz_i = self.dataset.node_xyz[i] - iou = img_util.compute_iou3d( - xyz_i, xyz_root, self.patch_shape, self.patch_shape - ) - if iou > 0.35: - merge_sites_set.remove(i) - self.node_preds[i] = 0 - - # Populate queue - for j in self.dataset.neighbors(i): - dist_j = dist_i + self.dataset.dist(i, j) - if j not in visited and dist_j < self.patch_shape[0]: - queue.append((j, dist_j)) - visited.add(j) - return filtered_merge_sites - - def remove_merge_sites(self, merge_site_nodes, max_depth=10): - rm_nodes = set() - for root in tqdm(merge_site_nodes, desc="Remove Merge Sites"): - # Extract neighborhood - root = self.dataset.find_nearby_branching_node(root) - nbhd = self.dataset.nodes_within_distance(root, max_depth) - - # Check for branching node in neighborhood - for i in list(nbhd): - if i != root and self.dataset.degree[i] >= 3: - nbhd_i = self.dataset.nodes_within_distance(root, 8) - nbhd.extend(nbhd_i) - - # Add nodes to removal list - rm_nodes.update(set(nbhd)) - - # Update graph - self.dataset.remove_nodes(rm_nodes) - print("# Nodes Deleted:", len(rm_nodes)) - - # --- Helpers --- - def get_detected_sites(self, threshold): - nodes = np.where(self.node_preds >= threshold)[0] - return [self.dataset.node_xyz[i] for i in nodes] - - def save_parameters(self, output_dir): - json_path = os.path.join(output_dir, "detection_parameters.json") - parameters = { - "accept_threshold": self.threshold, - "is_multimodal": self.dataset.is_multimodal, - "min_search_size": self.dataset.min_size, - "patch_shape": self.patch_shape, - "remove_detected_sites": self.remove_detected_sites, - "search_mode": self.dataset.search_mode, - "subgraph_radius": self.dataset.subgraph_radius, - } - util.write_json(json_path, parameters) - - def save_results( - self, output_dir, output_prefix_s3=None, save_fragments=True - ): - self.save_sites(output_dir) - if save_fragments: - self.dataset.graph.node_radius = 10 * np.maximum(self.node_preds, 0.1) - fragments_path = os.path.join(output_dir, "fragments.zip") - self.dataset.to_zipped_swcs(fragments_path, use_radius=True) - - # Upload results to S3 (if applicable) - if output_prefix_s3: - bucket_name, prefix = util.parse_cloud_path(output_prefix_s3) - util.upload_dir_to_s3(output_dir, bucket_name, prefix) - - def save_sites(self, output_dir): - # Save model predictions - df = pd.DataFrame( - columns=["World", "Segment_ID", "Prediction", "Degree"] - ) - df["World"] = list(map(tuple, self.dataset.node_xyz)) - df["Prediction"] = self.node_preds - df["Segment_ID"] = [ - self.dataset.node_segment_id(i) for i in self.dataset.nodes - ] - df["Degree"] = [self.dataset.degree[i] for i in self.dataset.nodes] - df.to_csv(os.path.join(output_dir, "model_predictions.csv")) - - # Get predicted merge sites - nodes = np.where(self.node_preds >= self.threshold)[0] - detected_sites = [self.dataset.node_xyz[i] for i in nodes] - print("# Sites Saved:", len(nodes)) - - # Save predicted merge sites - zip_path = os.path.join(output_dir, "detected_sites.zip") - swc_util.write_points( - zip_path, - detected_sites, - color="1.0 0.0 0.0", - prefix="merge-site", - radius=10, - ) - - def save_train_dataset(self, output_dir): - # Extract fragments to save - roots = list() - visited_ids = set() - for i in np.where(self.node_preds >= self.threshold)[0]: - cc_id = self.dataset.node_component_id[i] - if cc_id not in visited_ids: - roots.append([i]) - visited_ids.add(cc_id) - - # Save fragments - zip_path = os.path.join(output_dir, "fragments.zip") - self.dataset._batch_to_zipped_swcs(roots, zip_path, False) - self.save_sites(output_dir) - print("# Fragments Saved:", len(roots)) - - -# --- Datasets --- -class GraphDataset(IterableDataset, ABC): - - def __init__( - self, - graph, - img_config, - is_multimodal=False, - min_search_size=0, - prefetch=64, - subgraph_radius=100, - ): - # Call parent class - super().__init__() - - # Instance attributes - self.distance_traversed = 0 - self.graph = graph - self.is_multimodal = is_multimodal - self.min_size = min_search_size - self.patch_shape = img_config.patch_shape - self.prefetch = prefetch - self.subgraph_radius = subgraph_radius - - # Batch getter - if is_multimodal: - self.generate_inputs = None - else: - self.generate_inputs = self.generate_patches - - # --- Core routines --- - def __iter__(self): - # Find fragment IDs to check - valid_ids = self.find_fragments_to_search() - - # Search graph - visited_ids = set() - for u in self.graph.leaf_nodes(): - component_id = self.node_component_id[u] - if component_id not in visited_ids and component_id in valid_ids: - visited_ids.add(component_id) - yield from self._generate_batches_from_component(u) - - @abstractmethod - def _generate_batches_from_component(self, root): - """ - Abstract method to be implemented by subclasses. - """ - pass - - @abstractmethod - def _generate_batch_nodes(self, root): - """ - Abstract method to be implemented by subclasses. - """ - - # --- Helpers --- - @abstractmethod - def estimate_iterations(self): - pass - - def find_fragments_to_search(self): - component_ids = set() - for nodes in nx.connected_components(self.graph): - # Compute path length - node = util.sample_once(list(nodes)) - length = self.graph.cable_length( - max_depth=self.min_size, root=node - ) - - # Check if path length satisfies threshold - if length > self.min_size: - component_ids.add(self.node_component_id[node]) - return component_ids - - def is_contained(self, node): - voxel = self.node_voxel(node) - shape = self.patch_loader.img.shape()[2::] - buffer = np.max(self.patch_shape) + 1 - return img_util.is_contained(voxel, shape, buffer=buffer) - - def is_near_leaf(self, node, threshold=20): - # Check if node is branching - if self.degree[node] > 2: - return False - - # Search neighborhood - queue = [(node, 0)] - visited = {node} - while len(queue) > 0: - # Visit node - i, dist_i = queue.pop() - if self.degree[i] == 1: - return True - - # Update queue - for j in self.neighbors(i): - dist_j = dist_i + self.dist(i, j) - if j not in visited and dist_j < threshold: - queue.append((j, dist_j)) - visited.add(j) - return False - - def is_node_valid(self, node): - is_contained = self.is_contained(node) - is_nonleaf = not self.is_near_leaf(node) - return is_contained and is_nonleaf - - # --- Helpers --- - def __getattr__(self, name): - return getattr(self.graph, name) - - -class DenseDataset(GraphDataset): - - max_batch_span = 512 - - def __init__( - self, - graph, - img_config, - is_multimodal=False, - min_search_size=0, - prefetch=64, - step_size=40, - subgraph_radius=100, - ): - # Call parent class - super().__init__( - graph, - img_config, - is_multimodal=is_multimodal, - min_search_size=min_search_size, - prefetch=prefetch, - subgraph_radius=subgraph_radius, - ) - - # Instance attributes - self.patch_loader = DetectionBatchLoader(self.graph, img_config) - self.search_mode = "dense" - self.step_size = step_size - - def _generate_batches_from_component(self, root): - # Subroutines - def submit_thread(): - try: - nodes = next(batch_nodes_generator) - thread = executor.submit(self.patch_loader, nodes) - pending[thread] = nodes - except StopIteration: - pass - - # Main - batch_nodes_generator = self._generate_batch_nodes(root) - with ThreadPoolExecutor(max_workers=128) as executor: - try: - # Prefetch batches - pending = dict() - for _ in range(self.prefetch): - submit_thread() - - # Yield batches - while pending: - done, _ = wait(pending.keys(), return_when=FIRST_COMPLETED) - for thread in done: - # Process completed thread - nodes = pending.pop(thread) - img, offset = thread.result() - yield from self.generate_inputs(nodes, img, offset) - - # Continue submitting threads - submit_thread() - finally: - pass - - def _generate_batch_nodes(self, root): - """ - Generates batches of nodes from the connected component that contains - the given root node. - - Returns - ------- - Iterator[numpy.ndarray] - Generator that yields batches of nodes from the connected - component containing the given root node. - """ - nodes = list() - for i, j in nx.dfs_edges(self.graph, source=root): - # Check if starting new batch - self.distance_traversed += self.dist(i, j) - if len(nodes) == 0: - if self.is_node_valid(i): - root = i - last_node = i - nodes.append(i) - else: - continue - - # Check whether to yield batch - if self.dist(root, j) > self.max_batch_span: - # Yield nodes in batch - yield np.array(nodes, dtype=int) - - # Reset batch metadata - nodes = list() - - # Visit j - is_next = self.dist(last_node, j) >= self.step_size - 2 - is_branching = self.degree[j] >= 3 - if (is_next or is_branching) and self.is_node_valid(j): - last_node = j - nodes.append(j) - if len(nodes) == 1: - root = j - - # Yield any remaining nodes after the loop - if nodes: - yield np.array(nodes, dtype=int) - - def generate_patches(self, nodes, img, offset): - voxels = np.array([self.node_voxel(i) for i in nodes], dtype=int) - for node, center in zip(nodes, voxels - offset): - s = img_util.get_slices(center, self.patch_shape) - patch = torch.from_numpy(img[(slice(0, 2), *s)]).float() - yield node, patch - - def _get_multimodal_batch(self, nodes, img, offset): - # Initializations - label_mask = self.get_label_mask(nodes, img.shape, offset) - patch_centers = self.get_patch_centers(nodes) - offset - - # Populate batch array - patches = np.empty((len(nodes), 2,) + self.patch_shape) - point_clouds = np.empty((len(nodes), 3, 3600), dtype=np.float32) - for i, (node, center) in enumerate(zip(nodes, patch_centers)): - s = img_util.get_slices(center, self.patch_shape) - patches[i, 0, ...] = img_util.normalize(img[s]) - patches[i, 1, ...] = label_mask[s] - - subgraph = self.graph.rooted_subgraph(node, self.subgraph_radius) - point_clouds[i] = subgraph_to_point_cloud(subgraph) - - # Build batch dictionary - batch = ml_util.TensorDict( - { - "img": ml_util.to_tensor(patches), - "point_cloud": ml_util.to_tensor(point_clouds), - } - ) - return nodes, batch - - # --- Helpers --- - def estimate_iterations(self): - """ - Estimates the number of iterations required to search graph. - - Returns - ------- - int - Estimated number of iterations required to search graph. - """ - # Search graph - total_cable_length = 0 - n_fragments = 0 - for nodes in map(list, nx.connected_components(self.graph)): - cable_length = self.cable_length(root=nodes[0]) - if cable_length > self.min_size: - total_cable_length += cable_length - n_fragments += 1 - - # Report results - print("# Fragments:", n_fragments) - print(f"Total Cable Length: {total_cable_length / 1000:.2f}mm") - return int(total_cable_length / self.step_size) - - -class SparseDataset(GraphDataset): - - def _generate_batches_from_component(self): - pass - - def _generate_batch_nodes(self, root): - nodes = list() - patch_centers = list() - for i, j in nx.dfs_edges(self.graph, source=root): - # Check if starting new batch - self.distance_traversed += self.graph.dist(i, j) - if len(patch_centers) == 0 and self.graph.degree[i] > 2: - root = i - nodes.append(i) - patch_centers.append(self.graph.node_voxel(i)) - - # Check whether to yield batch - if self.graph.dist(root, j) > self.max_batch_span: - # Yield batch metadata - patch_centers = np.array(patch_centers, dtype=int) - nodes = np.array(nodes, dtype=int) - yield nodes, patch_centers - - # Reset batch metadata - nodes = list() - patch_centers = list() - - # Visit j - if self.graph.degree[j] > 2: - nodes.append(j) - patch_centers.append(self.graph.node_voxel(j)) - if len(patch_centers) == 1: - root = j - - # --- Helpers --- - def estimate_iterations(self): - """ - Estimates the number of iterations required to search graph. - - Returns - ------- - int - Estimated number of iterations required to search graph. - """ - return len(self.graph.get_branchings()) - - -class BranchingDataset(GraphDataset): - - def __init__( - self, - graph, - img_config, - is_multimodal=False, - min_search_size=0, - prefetch=64, - step_size=10, - subgraph_radius=100, - ): - # Call parent class - super().__init__( - graph, - img_config, - is_multimodal=is_multimodal, - min_search_size=min_search_size, - prefetch=prefetch, - subgraph_radius=subgraph_radius, - ) - - # Instance attributes - self.patch_loader = DetectionPatchLoader(self.graph, img_config) - self.search_mode = "branching_nodes" diff --git a/src/neuron_proofreader/merge_proofreading/search_datasets.py b/src/neuron_proofreader/merge_proofreading/search_datasets.py new file mode 100644 index 00000000..473685d1 --- /dev/null +++ b/src/neuron_proofreader/merge_proofreading/search_datasets.py @@ -0,0 +1,319 @@ +""" +Created on Wed August 4 16:00:00 2025 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +Code for detecting merge mistakes on skeletons generated from an automated +image segmentation. + +""" + +from abc import ABC, abstractmethod +from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait +from torch.utils.data import IterableDataset + +import networkx as nx +import numpy as np +import torch + +from neuron_proofreader.machine_learning.point_cloud_models import ( + subgraph_to_point_cloud, +) +from neuron_proofreader.machine_learning.image_dataloader import ( + DetectionBatchLoader, + DetectionPatchLoader, +) +from neuron_proofreader.utils import img_util, ml_util, util + + +# --- Datasets --- +class SearchDataset(IterableDataset, ABC): + + def __init__( + self, + graph, + img_config, + is_multimodal=False, + min_search_size=0, + prefetch=32, + subgraph_radius=100, + ): + # Call parent class + super().__init__() + + # Instance attributes + self.distance_traversed = 0 + self.graph = graph + self.is_multimodal = is_multimodal + self.min_size = min_search_size + self.patch_shape = img_config.patch_shape + self.prefetch = prefetch + self.subgraph_radius = subgraph_radius + + # Input getter + if is_multimodal: + self.get_input = self.get_patch_and_pointcloud + else: + self.get_input = self.get_patch + + # --- Core routines --- + def __iter__(self): + sites_generator = self._all_sites() + with ThreadPoolExecutor(max_workers=self.prefetch) as executor: + # Submit initial threads + pending = {} + for _ in range(self.prefetch): + self._submit(executor, sites_generator, pending) + + # Process remaining jobs until complete + while pending: + done, _ = wait(pending.keys(), return_when=FIRST_COMPLETED) + for thread in done: + site = pending.pop(thread) + yield from self.get_input(site, *thread.result()) + self._submit(executor, sites_generator, pending) + + def _submit(self, executor, sites_generator, pending): + try: + site = next(sites_generator) + pending[executor.submit(self.patch_loader, site)] = site + except StopIteration: + pass + + def _all_sites(self): + visited_ids = set() + valid_ids = self.find_fragments_to_search() + for u in self.graph.leaf_nodes(): + component_id = self.node_component_id[u] + if component_id not in visited_ids and component_id in valid_ids: + visited_ids.add(component_id) + yield from self.generate_component_sites(u) + + @abstractmethod + def generate_component_sites(self, root): + """ + Abstract method to be implemented by subclasses. + """ + pass + + # --- Helpers --- + def __getattr__(self, name): + return getattr(self.graph, name) + + @abstractmethod + def estimate_iterations(self): + pass + + def find_fragments_to_search(self): + component_ids = set() + for nodes in nx.connected_components(self.graph): + # Compute path length + node = util.sample_once(list(nodes)) + length = self.graph.cable_length( + max_depth=self.min_size, root=node + ) + + # Check if path length satisfies threshold + if length > self.min_size: + component_ids.add(self.node_component_id[node]) + return component_ids + + def is_contained(self, node): + voxel = self.node_voxel(node) + shape = self.patch_loader.img.shape()[2::] + buffer = np.max(self.patch_shape) + 1 + return img_util.is_contained(voxel, shape, buffer=buffer) + + def is_near_leaf(self, node, threshold=32): + # Check if node is branching + if self.degree[node] > 2: + return False + + # Search neighborhood + queue = [(node, 0)] + visited = {node} + while len(queue) > 0: + # Visit node + i, dist_i = queue.pop() + if self.degree[i] == 1: + return True + + # Update queue + for j in self.neighbors(i): + dist_j = dist_i + self.dist(i, j) + if j not in visited and dist_j < threshold: + queue.append((j, dist_j)) + visited.add(j) + return False + + def is_node_valid(self, node): + is_contained = self.is_contained(node) + is_nonleaf = not self.is_near_leaf(node) + return is_contained and is_nonleaf + + +class DenseSearchDataset(SearchDataset): + + max_batch_span = 512 + + def __init__( + self, + graph, + img_config, + is_multimodal=False, + min_search_size=0, + prefetch=64, + step_size=40, + subgraph_radius=100, + ): + # Call parent class + super().__init__( + graph, + img_config, + is_multimodal=is_multimodal, + min_search_size=min_search_size, + prefetch=prefetch, + subgraph_radius=subgraph_radius, + ) + + # Instance attributes + self.patch_loader = DetectionBatchLoader(self.graph, img_config) + self.search_mode = "dense" + self.step_size = step_size + + def generate_component_sites(self, root): + """ + Generates batches of nodes from the connected component that contains + the given root node. + + Returns + ------- + Iterator[numpy.ndarray] + Generator that yields batches of nodes from the connected + component containing the given root node. + """ + nodes = list() + for i, j in nx.dfs_edges(self.graph, source=root): + # Check if starting new batch + self.distance_traversed += self.dist(i, j) + if len(nodes) == 0: + if self.is_node_valid(i): + root = i + last_node = i + nodes.append(i) + else: + continue + + # Check whether to yield batch + if self.dist(root, j) > self.max_batch_span: + yield np.array(nodes, dtype=int) + nodes = list() + + # Visit j + is_next = self.dist(last_node, j) >= self.step_size - 2 + is_branching = self.degree[j] >= 3 + if (is_next or is_branching) and self.is_node_valid(j): + last_node = j + nodes.append(j) + if len(nodes) == 1: + root = j + + # Yield any remaining nodes after the loop + if nodes: + yield np.array(nodes, dtype=int) + + def get_patch(self, nodes, img, offset): + img = torch.from_numpy(img).float() + voxels = np.array([self.node_voxel(i) for i in nodes], dtype=int) + for node, center in zip(nodes, voxels - offset): + s = img_util.get_slices(center, self.patch_shape) + yield node, img[(slice(0, 2), *s)] + + def generate_patch_and_pc(self, nodes, img, offset): + # Initializations + label_mask = self.get_label_mask(nodes, img.shape, offset) + patch_centers = self.get_patch_centers(nodes) - offset + + # Populate batch array + patches = np.empty((len(nodes), 2,) + self.patch_shape) + point_clouds = np.empty((len(nodes), 3, 3600), dtype=np.float32) + for i, (node, center) in enumerate(zip(nodes, patch_centers)): + s = img_util.get_slices(center, self.patch_shape) + patches[i, 0, ...] = img_util.normalize(img[s]) + patches[i, 1, ...] = label_mask[s] + + subgraph = self.graph.rooted_subgraph(node, self.subgraph_radius) + point_clouds[i] = subgraph_to_point_cloud(subgraph) + + # Build batch dictionary + batch = ml_util.TensorDict( + { + "img": ml_util.to_tensor(patches), + "point_cloud": ml_util.to_tensor(point_clouds), + } + ) + return nodes, batch + + # --- Helpers --- + def estimate_iterations(self): + """ + Estimates the number of iterations required to search graph. + + Returns + ------- + int + Estimated number of iterations required to search graph. + """ + # Search graph + total_cable_length = 0 + n_fragments = 0 + for nodes in map(list, nx.connected_components(self.graph)): + cable_length = self.cable_length(root=nodes[0]) + if cable_length > self.min_size: + total_cable_length += cable_length + n_fragments += 1 + + # Report results + print("# Fragments:", n_fragments) + print(f"Total Cable Length: {total_cable_length / 1000:.2f}mm") + return int(total_cable_length / self.step_size) + + +class SparseSearchDataset(SearchDataset): + pass + + +class BranchingSearchDataset(SearchDataset): + + def __init__( + self, + graph, + img_config, + is_multimodal=False, + min_search_size=0, + prefetch=64, + step_size=10, + subgraph_radius=100, + ): + # Call parent class + super().__init__( + graph, + img_config, + is_multimodal=is_multimodal, + min_search_size=min_search_size, + prefetch=prefetch, + subgraph_radius=subgraph_radius, + ) + + # Instance attributes + self.patch_loader = DetectionPatchLoader(self.graph, img_config) + self.search_mode = "branching_nodes" + + def generate_component_sites(self, root): + for i, j in nx.dfs_edges(self.graph, source=root): + if self.degree[i] >= 3: + yield i + + def get_patch(self, node, img): + yield node, torch.from_numpy(img).float() From 39d121069ead3d79d208450779c74128ceaae985 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Tue, 16 Jun 2026 03:44:08 +0000 Subject: [PATCH 5/5] resolve merge conflict --- src/neuron_proofreader/merge_proofreading/merge_detection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neuron_proofreader/merge_proofreading/merge_detection.py b/src/neuron_proofreader/merge_proofreading/merge_detection.py index 280009be..c0736a06 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_detection.py +++ b/src/neuron_proofreader/merge_proofreading/merge_detection.py @@ -32,7 +32,7 @@ def __init__( batch_size=16, device="cuda", remove_detected_sites=False, - threshold=0.4, + threshold=0.5, ): # Instance attributes self.batch_size = batch_size