From fef942df682a95ab8ee2ed9d3bedec8acbf7543a Mon Sep 17 00:00:00 2001 From: anna-grim Date: Wed, 27 May 2026 19:37:37 +0000 Subject: [PATCH 01/11] refactor: update merge training --- doc_template/source/conf.py | 1 + src/neuron_proofreader/__init__.py | 1 + .../machine_learning/augmentation.py | 25 +- .../machine_learning/exaspim_dataloader.py | 20 +- .../machine_learning/geometric_gnn_models.py | 23 +- .../machine_learning/gnn_models.py | 21 +- .../machine_learning/train.py | 8 +- .../machine_learning/vision_models.py | 2 +- .../merge_proofreading/merge_dataloading.py | 4 +- .../merge_datamodules_v2.py | 630 ++++++++++++++++++ .../merge_proofreading/merge_datasets.py | 24 +- .../merge_proofreading/merge_inference.py | 30 +- src/neuron_proofreader/proposal_graph.py | 5 +- src/neuron_proofreader/skeleton_graph.py | 2 +- .../split_proofreading/proposal_generation.py | 2 +- .../split_feature_extraction.py | 5 +- .../split_proofreading/split_inference.py | 8 +- src/neuron_proofreader/utils/graph_util.py | 17 +- src/neuron_proofreader/utils/img_util.py | 185 ++--- src/neuron_proofreader/utils/swc_util.py | 597 ++++++++--------- src/neuron_proofreader/utils/util.py | 187 ++++-- tests/__init__.py | 1 + 22 files changed, 1221 insertions(+), 577 deletions(-) create mode 100644 src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py diff --git a/doc_template/source/conf.py b/doc_template/source/conf.py index b38833ac..6e55fb07 100644 --- a/doc_template/source/conf.py +++ b/doc_template/source/conf.py @@ -1,4 +1,5 @@ """Configuration file for the Sphinx documentation builder.""" + # # For the full list of built-in configuration values, see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html diff --git a/src/neuron_proofreader/__init__.py b/src/neuron_proofreader/__init__.py index d0a85479..024d51f2 100644 --- a/src/neuron_proofreader/__init__.py +++ b/src/neuron_proofreader/__init__.py @@ -1,2 +1,3 @@ """Init package""" + __version__ = "0.0.0" diff --git a/src/neuron_proofreader/machine_learning/augmentation.py b/src/neuron_proofreader/machine_learning/augmentation.py index 91951eba..25429dce 100644 --- a/src/neuron_proofreader/machine_learning/augmentation.py +++ b/src/neuron_proofreader/machine_learning/augmentation.py @@ -30,7 +30,7 @@ def __init__(self): RandomFlip3D(), RandomRotation3D(), RandomNoise3D(), - RandomContrast3D() + RandomContrast3D(), ] def __call__(self, patches): @@ -79,8 +79,8 @@ def __call__(self, patches): """ for axis in self.axes: if random.random() > 0.5: - patches[0, ...] = np.flip(patches[0, ...], axis=axis) - patches[1, ...] = np.flip(patches[1, ...], axis=axis) + patches[0] = np.flip(patches[0], axis=axis) + patches[1] = np.flip(patches[1], axis=axis) return patches @@ -116,8 +116,8 @@ def __call__(self, patches): for axes in self.axes: if random.random() < 0.5: angle = random.uniform(*self.angles) - self.rotate3d(patches[0, ...], angle, axes, False) - self.rotate3d(patches[1, ...], angle, axes, True) + patches[0] = self.rotate3d(patches[0], angle, axes, False) + patches[1] = self.rotate3d(patches[1], angle, axes, True) return patches @staticmethod @@ -149,6 +149,7 @@ def rotate3d(img_patch, angle, axes, is_segmentation=False): order=order, ) img_patch /= multipler + return img_patch class RandomScale3D: @@ -197,8 +198,8 @@ def __call__(self, patches): ] # Rescale images - patches[0, ...] = zoom(patches[0, ...], zoom_factors, order=3) - patches[1, ...] = zoom(patches[1, ...], zoom_factors, order=0) + patches[0] = zoom(patches[0], zoom_factors, order=3) + patches[1] = zoom(patches[1], zoom_factors, order=0) return patches @@ -208,7 +209,7 @@ class RandomContrast3D: Adjusts the contrast of a 3D image by scaling voxel intensities. """ - def __init__(self, p_low=(0, 90), p_high=(97.5, 100)): + def __init__(self, p_low=(0, 80), p_high=(98, 100)): """ Initializes a RandomContrast3D transformer. @@ -253,7 +254,7 @@ def __init__(self, max_std=0.2): """ self.max_std = max_std - def __call__(self, img_patches): + def __call__(self, patches): """ Adds Gaussian noise to the input 3D image. @@ -264,6 +265,6 @@ def __call__(self, img_patches): the input image and "patches[1, ...]" is from the segmentation. """ std = self.max_std * random.random() - img_patches[0] += np.random.uniform(-std, std, img_patches[0].shape) - img_patches[0] = np.clip(img_patches[0], 0, 1) - return img_patches + patches[0] += np.random.uniform(-std, std, patches[0].shape) + patches[0] = np.clip(patches[0], 0, 1) + return patches diff --git a/src/neuron_proofreader/machine_learning/exaspim_dataloader.py b/src/neuron_proofreader/machine_learning/exaspim_dataloader.py index 997a7cca..7724dc86 100644 --- a/src/neuron_proofreader/machine_learning/exaspim_dataloader.py +++ b/src/neuron_proofreader/machine_learning/exaspim_dataloader.py @@ -328,9 +328,7 @@ def sample_bright_voxel(self, brain_id): pending = dict() for _ in range(self.prefetch_foreground_sampling): voxel = self.sample_interior_voxel(brain_id) - thread = executor.submit( - self.read_image, brain_id, voxel - ) + thread = executor.submit(self.read_image, brain_id, voxel) pending[thread] = voxel # Check if image patch is bright enough @@ -489,8 +487,20 @@ def _load_batch(self, start_idx): ) # Process results - img_patches = np.zeros((batch_size, 1,) + self.patch_shape) - mask_patches = np.zeros((batch_size, 1,) + self.patch_shape) + img_patches = np.zeros( + ( + batch_size, + 1, + ) + + self.patch_shape + ) + mask_patches = np.zeros( + ( + batch_size, + 1, + ) + + self.patch_shape + ) for i, process in enumerate(as_completed(processes)): img, mask = process.result() img_patches[i, 0, ...] = img diff --git a/src/neuron_proofreader/machine_learning/geometric_gnn_models.py b/src/neuron_proofreader/machine_learning/geometric_gnn_models.py index 4f65c90d..0c68d7f1 100644 --- a/src/neuron_proofreader/machine_learning/geometric_gnn_models.py +++ b/src/neuron_proofreader/machine_learning/geometric_gnn_models.py @@ -75,9 +75,7 @@ def __init__(self, ggnn_name, output_dim=64): # Set geometric gnn if ggnn_name == "egnn": self.geometric_gnn = EGNN( - in_node_dim=1, - hidden_dim=32, - out_node_dim=output_dim + in_node_dim=1, hidden_dim=32, out_node_dim=output_dim ) # --- Core Routines --- @@ -97,7 +95,9 @@ def forward(self, h, x, edge_index, batch): ) # Pool node embeddings - h_g, x_g, edge_index_g = self.pool_nonbranching_paths(h_g, x_g, edge_index_g) + h_g, x_g, edge_index_g = self.pool_nonbranching_paths( + h_g, x_g, edge_index_g + ) # Encode pooled graph h_g = self.encode_pooled_graph(h_g, x_g, edge_index_g) @@ -158,7 +158,9 @@ def pool_nonbranching_paths(self, h, x, edge_index): # Finish h_pooled = torch.stack(h_pooled, dim=0) x_pooled = torch.stack(x_pooled, dim=0) - edge_index_pooled = self.get_edge_index_pooled(edge_index, node_to_path) + edge_index_pooled = self.get_edge_index_pooled( + edge_index, node_to_path + ) return h_pooled, x_pooled, edge_index_pooled def get_adj_and_deg(self, edge_index, num_nodes): @@ -201,10 +203,13 @@ def extract_subgraph(self, h, x, edge_index, node_mask): id_map = {int(n): i for i, n in enumerate(node_ids.tolist())} edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]] edge_index_g = edge_index[:, edge_mask] - edge_index_g = torch.stack([ - torch.tensor([id_map[int(u)] for u in edge_index_g[0]]), - torch.tensor([id_map[int(v)] for v in edge_index_g[1]]) - ], dim=0) + edge_index_g = torch.stack( + [ + torch.tensor([id_map[int(u)] for u in edge_index_g[0]]), + torch.tensor([id_map[int(v)] for v in edge_index_g[1]]), + ], + dim=0, + ) return h_g, x_g, edge_index_g @staticmethod diff --git a/src/neuron_proofreader/machine_learning/gnn_models.py b/src/neuron_proofreader/machine_learning/gnn_models.py index c8954ec3..b5add009 100644 --- a/src/neuron_proofreader/machine_learning/gnn_models.py +++ b/src/neuron_proofreader/machine_learning/gnn_models.py @@ -26,6 +26,7 @@ class VisionHGAT(torch.nn.Module): Heterogeneous graph attention network that processes multimodal features such as image patches and feature vectors. """ + # Class attributes relations = [ str(("branch", "to", "branch")), @@ -47,7 +48,9 @@ def __init__( # Initial embeddings self.node_embedding = init_node_embedding(hidden_dim) - self.patch_embedding = init_patch_embedding(patch_shape, hidden_dim // 2) + self.patch_embedding = init_patch_embedding( + patch_shape, hidden_dim // 2 + ) # Message passing layers self.disable_msg_passing = disable_msg_passing @@ -58,7 +61,7 @@ def __init__( else: self.gat1 = self.init_gat(hidden_dim, hidden_dim, heads) self.gat2 = self.init_gat(hidden_dim * heads, hidden_dim, heads) - self.output = nn.Linear(hidden_dim * heads ** 2, 1) + self.output = nn.Linear(hidden_dim * heads**2, 1) # Initialize weights self.init_weights() @@ -81,9 +84,7 @@ def init_mlp_layers(self, hidden_dim, n_layers=2): for _ in range(n_layers): layers.append( nn_geometric.HeteroDictLinear( - hidden_dim, - hidden_dim, - types=("branch", "proposal") + hidden_dim, hidden_dim, types=("branch", "proposal") ) ) return layers @@ -160,10 +161,12 @@ def init_node_embedding(output_dim): dim_p = node_input_dims["proposal"] # Set node embedding layer - node_embedding = nn.ModuleDict({ - "branch": FeedForwardNet(dim_b, output_dim, 3), - "proposal": FeedForwardNet(dim_p, output_dim // 2, 3), - }) + node_embedding = nn.ModuleDict( + { + "branch": FeedForwardNet(dim_b, output_dim, 3), + "proposal": FeedForwardNet(dim_p, output_dim // 2, 3), + } + ) return node_embedding diff --git a/src/neuron_proofreader/machine_learning/train.py b/src/neuron_proofreader/machine_learning/train.py index bce9b6e3..20edafb4 100644 --- a/src/neuron_proofreader/machine_learning/train.py +++ b/src/neuron_proofreader/machine_learning/train.py @@ -68,7 +68,7 @@ def __init__( lr=1e-3, max_epochs=200, min_recall=0, - save_mistake_mips=False + save_mistake_mips=False, ): """ Instantiates a Trainer object. @@ -292,7 +292,7 @@ def compute_stats(y, hat_y): "f1": avg_f1, "precision": avg_prec, "recall": avg_recall, - "accuracy": avg_acc + "accuracy": avg_acc, } return stats @@ -426,7 +426,7 @@ def __init__( device="cuda", lr=1e-3, max_epochs=200, - save_mistake_mips=False + save_mistake_mips=False, ): """ Instantiates a DistributedTrainer object. @@ -452,7 +452,7 @@ def __init__( device=device, lr=lr, max_epochs=max_epochs, - save_mistake_mips=save_mistake_mips + save_mistake_mips=save_mistake_mips, ) # Check that multiple GPUs are available diff --git a/src/neuron_proofreader/machine_learning/vision_models.py b/src/neuron_proofreader/machine_learning/vision_models.py index bf99fe65..21b6760d 100644 --- a/src/neuron_proofreader/machine_learning/vision_models.py +++ b/src/neuron_proofreader/machine_learning/vision_models.py @@ -142,7 +142,7 @@ def __init__(self, checkpoint_path, model_config): checkpoint_path=checkpoint_path, model_config=model_config, task_head_config="binary_classifier", - freeze_encoder=True + freeze_encoder=True, ) # Instance attributes diff --git a/src/neuron_proofreader/merge_proofreading/merge_dataloading.py b/src/neuron_proofreader/merge_proofreading/merge_dataloading.py index 0bf8656d..d9c6fd50 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_dataloading.py +++ b/src/neuron_proofreader/merge_proofreading/merge_dataloading.py @@ -43,7 +43,9 @@ def load_fragments(dataset, is_test=False): sub_df = merge_sites_df.loc[merge_sites_df["brain_id"] == brain_id] for segmentation_id in sub_df["segmentation_id"].unique(): if (brain_id, segmentation_id) in target_pairs: - swc_pointer = f"{root}/{brain_id}/{segmentation_id}/merged_fragments.zip" + swc_pointer = ( + f"{root}/{brain_id}/{segmentation_id}/merged_fragments.zip" + ) dataset.load_fragment_graphs( brain_id, swc_pointer, use_anisotropy=False ) diff --git a/src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py b/src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py new file mode 100644 index 00000000..585892f4 --- /dev/null +++ b/src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py @@ -0,0 +1,630 @@ +""" +Created on Tue June 26 12:00:00 2026 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +Dataset and dataloader utilities for processing merge site data to train a +model to detect merge errors. + +Architecture +------------ +BrainDataset + Owns all data for a single brain: fragment graph, GT graph, image/ + segmentation readers, merge-site KD-tree, and the slice of + merge_sites_df that belongs to this brain. Handles all per-brain + positive/negative site retrieval. + +BrainDatasetCollection + Holds an ordered list of BrainDataset objects. Routes global indices to + the correct brain, exposes split() for train/val partitioning, and is + the object handed to MergeSiteDataLoader. + +MergeSiteDataLoader + Custom DataLoader that uses multithreading to fetch image patches from + cloud storage and assemble batches. +""" + +from scipy.spatial import KDTree +from concurrent.futures import as_completed, ThreadPoolExecutor +from torch.utils.data import Dataset, DataLoader + +import networkx as nx +import numpy as np +import os +import queue +import random +import threading +import torch + +from neuron_proofreader.machine_learning.augmentation import ImageTransforms +from neuron_proofreader.machine_learning.geometric_gnn_models import ( + subgraph_to_data, +) +from neuron_proofreader.machine_learning.point_cloud_models import ( + subgraph_to_point_cloud, +) +from neuron_proofreader.skeleton_graph import SkeletonGraph +from neuron_proofreader.utils import ( + geometry_util, + img_util, + ml_util, + swc_util, + util, +) + + +# --------------------------------------------------------------------------- +# BrainDataset +# --------------------------------------------------------------------------- + + +class BrainDataset: + """ + All data and retrieval logic for a single whole-brain dataset. + + Parameters + ---------- + brain_id : str + Unique identifier for this brain. + anisotropy : Tuple[float] + Voxel-to-physical scaling factors. + brightness_clip : int + Maximum raw image intensity before normalisation. + subgraph_radius : int + Radius (um) used when extracting rooted subgraphs. + node_spacing : int + Spacing (um) between neighbouring graph nodes. + patch_shape : Tuple[int] + Shape of the 3D image patches to extract. + use_segmentation_mask : bool + Whether to overlay a volumetric segmentation when building the + segment mask. + """ + + def __init__( + self, + brain_id, + sites_prefix, + pos_site_paths, + neg_site_paths, + img_path, + anisotropy=(1.0, 1.0, 1.0), + brightness_clip=500, + subgraph_radius=100, + node_spacing=5, + patch_shape=(128, 128, 128), + probability_random_nonmerge_site=0.5, + use_transform=False, + use_segmentation_mask=False, + ): + # Instance attributes + self.anisotropy = anisotropy + self.brain_id = brain_id + self.brightness_clip = brightness_clip + self.subgraph_radius = subgraph_radius + self.node_spacing = node_spacing + self.patch_shape = patch_shape + self.probability_random_nonmerge_site = ( + probability_random_nonmerge_site + ) + self.use_segmentation_mask = use_segmentation_mask + + # Core data structures + self.graph = self.load_fragments(sites_prefix) + self.img = img_util.TensorStoreImage(img_path) + self.segmentation_reader = None + + self.nonmerge_sites = self.load_sites(neg_site_paths) + self.merge_sites = self.load_sites(pos_site_paths) + self.set_merge_site_info() + + # Image augmentation for training + self.transform = ImageTransforms() if use_transform else None + + def load_fragments(self, sites_prefix): + graph = SkeletonGraph( + anisotropy=self.anisotropy, + node_spacing=self.node_spacing, + use_anisotropy=False, + verbose=True, + ) + graph.load(os.path.join(sites_prefix, "fragments")) + return graph + + def load_sites(self, site_paths): + sites = list() + swc_reader = swc_util.Reader(verbose=False) + for swc_dict in swc_reader(site_paths): + sites.append(swc_dict["xyz"]) + return np.vstack(sites) + + def set_merge_site_info(self): + # Build kdtree of merge sites + self.merge_sites_kdtree = KDTree(self.merge_sites) + + # Store fragment IDs corresponding to merge sites + self.fragments_with_merge = set() + for xyz in self.merge_sites: + _, ii = self.graph.kdtree.query(xyz) + self.fragments_with_merge.add(self.graph.node_component_id[ii]) + + # --- Site retrieval --- + def __getitem__(self, idx): + # Get example + node, label = self.get_site(idx) + subgraph = self.graph.rooted_subgraph(node, self.subgraph_radius) + + # Get voxel coordinate + voxel = subgraph.node_voxel(0) + if self.transform: + voxel += np.random.randint(-6, 6 + 1, size=3) + + # Extract subgraph and image patches centered at site + img_patch = self.get_img_patch(voxel) + segment_mask = self.get_segment_mask(voxel, subgraph) + + # Stack image channels + try: + patches = np.stack([img_patch, segment_mask], axis=0) + except ValueError: + img_patch = img_util.pad_to_shape(img_patch, self.patch_shape) + patches = np.stack([img_patch, segment_mask], axis=0) + return patches, subgraph, label + + def get_site(self, idx): + if idx > 0: + return self.get_merge_site(idx) + elif np.random.random() < self.probability_random_nonmerge_site: + return self.get_random_nonmerge_site() + elif abs(idx) < len(self.nonmerge_sites): + return self.get_indexed_nonmerge_site(abs(idx)) + else: + return self.get_random_nonmerge_site() + + def get_merge_site(self, idx): + _, node = self.graph.kdtree.query(self.merge_sites[idx]) + return node, 1 + + def get_indexed_nonmerge_site(self, idx): + _, node = self.graph.kdtree.query(self.nonmerge_sites[idx]) + return node, 0 + + def get_random_nonmerge_site(self): + # Search for valid nonmerge site + branching_nodes = self.graph.branching_nodes() + use_branching = branching_nodes and random.random() < 0.5 + for cnt in range(10**4): + # Sample node + if use_branching: + node = util.sample_once(branching_nodes) + else: + node = util.sample_once(self.graph.nodes) + + # Reject if high-degree + if self.graph.degree(node) > 3: + continue + + # Reject if branching and near another branching node + if use_branching and self._has_nearby_branching(node): + continue + + # Reject if near merge site + dd, _ = self.merge_sites_kdtree.query(self.graph.node_xyz[node]) + if dd < 100: + continue + + # Site is valid + break + + return node, 0 + + # --- Image / Mask Extraction --- + def get_img_patch(self, center): + """ + Extracts, clips, and normalises a 3D image patch centred at center. + + Parameters + ---------- + center : numpy.ndarray + Voxel coordinates of the patch centre. + + Returns + ------- + numpy.ndarray + """ + patch = self.img.read(center, self.patch_shape) + patch = np.minimum(patch, self.brightness_clip) + return img_util.normalize(patch) + + def get_segment_mask(self, center, subgraph): + """ + Builds the segment mask for subgraph, optionally incorporating a + volumetric segmentation read. + + Parameters + ---------- + center : numpy.ndarray + Voxel coordinates of the patch centre. + subgraph : SkeletonGraph + + Returns + ------- + numpy.ndarray + """ + if self.use_segmentation_mask: + return self._segment_mask_with_segmentation(center, subgraph) + return self._segment_mask_skeleton_only(subgraph) + + def _segment_mask_skeleton_only(self, subgraph): + mask = np.zeros(self.patch_shape) + center = subgraph.node_voxel(0) + offset = img_util.get_offset(center, self.patch_shape) + for node1, node2 in subgraph.edges: + v1 = subgraph.node_local_voxel(node1, offset) + v2 = subgraph.node_local_voxel(node2, offset) + img_util.annotate_voxels( + mask, geometry_util.make_digital_line(v1, v2) + ) + return mask + + def _segment_mask_with_segmentation(self, center, subgraph): + mask = self.segmentation_reader.read(center, self.patch_shape) + mask = img_util.remove_small_segments(mask, 1000) + mask = 0.5 * (mask > 0).astype(float) + offset = img_util.get_offset(center, self.patch_shape) + for node1, node2 in subgraph.edges: + v1 = subgraph.node_local_voxel(node1, offset) + v2 = subgraph.node_local_voxel(node2, offset) + img_util.annotate_voxels( + mask, geometry_util.make_digital_line(v1, v2) + ) + return mask + + # --- Private helpers --- + def _list_indices(self): + # Set idxs + pos_idxs = np.arange(len(self.merge_sites)) + neg_idxs = np.arange(len(self.nonmerge_sites)) + + # Check for class imbalance + if len(neg_idxs) < len(pos_idxs): + neg_idxs = -pos_idxs + else: + neg_idxs = -np.random.choice( + neg_idxs, size=len(pos_idxs), replace=False + ) + return np.concatenate((pos_idxs, neg_idxs)) + + def _has_nearby_branching(self, root, max_depth=60): + queue = [(root, 0)] + visited = {root} + while queue: + # Visit node + i, d_i = queue.pop() + if self.graph.degree[i] > 2 and d_i > 0: + return True + + # Update queue + for j in self.graph.neighbors(i): + d_j = d_i + self.graph.dist(i, j) + if j not in visited and d_j < max_depth: + queue.append((j, d_j)) + visited.add(j) + return False + + def __len__(self): + return len(self._list_indices()) + + +# --------------------------------------------------------------------------- +# BrainDatasetCollection +# --------------------------------------------------------------------------- + + +class BrainDatasetCollection(Dataset): + """ + An ordered collection of BrainDataset objects that presents a unified + Dataset interface to MergeSiteDataLoader. + + Global indices are mapped to (brain_idx, local_idx) pairs via a flat + index table built from each brain's _list_indices(). The BrainDataset + handles all site dispatch and retrieval internally. + + Parameters + ---------- + brain_datasets : List[BrainDataset] + One BrainDataset per brain. + augmentation : callable or None + Applied to (2, D, H, W) patch arrays in-place during __getitem__. + Pass None when augmentation is not needed (e.g. validation). + """ + + def __init__(self, brain_datasets): + self.brain_datasets = brain_datasets + self._index_table = self._build_index_table() + + def _build_index_table(self): + """ + Builds a flat list of (brain_idx, local_idx) pairs by concatenating + each brain's _list_indices(). Rebuilt whenever the collection changes. + + Returns + ------- + List[Tuple[int, int]] + """ + table = [] + for b_idx, bd in enumerate(self.brain_datasets): + for local_idx in bd._list_indices(): + table.append((b_idx, int(local_idx))) + return table + + # --- Dataset interface --- + def __len__(self): + return len(self._index_table) + + def __getitem__(self, idx): + """ + Returns one example: (patches, subgraph, label). + + Parameters + ---------- + idx : int + Index into the flat index table. + + Returns + ------- + patches : numpy.ndarray shape (2, D, H, W) + subgraph : SkeletonGraph + label : int + """ + b_idx, local_idx = self._index_table[idx] + return self.brain_datasets[b_idx][local_idx] + + def get_idxs(self): + """ + Returns shuffleable indices over the full index table. + + Returns + ------- + numpy.ndarray + """ + return np.arange(len(self._index_table)) + + # --- Helpers --- + + def brain_ids(self): + """Returns the list of brain IDs in this collection.""" + return [bd.brain_id for bd in self.brain_datasets] + + def n_merge_sites(self): + """Returns the total number of merge sites across all brains.""" + return sum(len(bd.merge_sites) for bd in self.brain_datasets) + + def count_fragments(self): + """Returns the total number of fragments across all brains.""" + return sum( + nx.number_connected_components(bd.graph) + for bd in self.brain_datasets + if bd.graph is not None + ) + + def __repr__(self): + return ( + f"BrainDatasetCollection(" + f"n_brains={len(self.brain_datasets)}, " + f"n_examples={len(self)})" + ) + + +# --------------------------------------------------------------------------- +# MergeSiteDataLoader +# --------------------------------------------------------------------------- + + +class ThreadedDataLoader(DataLoader): + + _VALID_MODALITIES = {None, "graph", "pointcloud"} + + def __init__( + self, + dataset, + batch_size=32, + is_multimodal=False, + modality=None, + sampler=None, + use_shuffle=True, + prefetch_batches=8, + ): + # Check that modality is valid + if modality not in self._VALID_MODALITIES: + raise ValueError( + f"modality must be one of {self._VALID_MODALITIES}, " + f"got {modality!r}." + ) + + # Call parent class + super().__init__(dataset, batch_size=batch_size, sampler=sampler) + + # Instance attributes + self.is_multimodal = is_multimodal + self.modality = modality + self.use_shuffle = use_shuffle + self.prefetch_batches = prefetch_batches + self.patches_shape = (2,) + dataset.brain_datasets[0].patch_shape + + # Set batch loader + if self.is_multimodal and self.modality == "graph": + self._load_batch = self._load_image_graph_batch + elif self.is_multimodal and self.modality == "pointcloud": + self._load_batch = self._load_image_pc_batch + else: + self._load_batch = self._load_image_batch + + def __iter__(self): + # Extract indices + self.dataset._index_table = self.dataset._build_index_table() + idxs = self.dataset.get_idxs() + if self.use_shuffle: + np.random.shuffle(idxs) + + # Split into batches upfront + batch_index_groups = [ + idxs[start: min(start + self.batch_size, len(idxs))] + for start in range(0, len(idxs), self.batch_size) + ] + + # Sentinel signalling the prefetch thread is done + _DONE = object() + buffer = queue.Queue(maxsize=self.prefetch_batches) + + def prefetch_worker(): + try: + for batch_idxs in batch_index_groups: + buffer.put(self._load_batch(batch_idxs)) + except Exception as e: + buffer.put(e) + finally: + buffer.put(_DONE) + + thread = threading.Thread(target=prefetch_worker, daemon=True) + thread.start() + + while True: + item = buffer.get() + if item is _DONE: + break + if isinstance(item, Exception): + raise item + yield item + + thread.join() + + def _load_image_batch(self, batch_idxs): + """ + Loads a batch of samples from the dataset using multithreading. + + Parameters + ---------- + batch_idxs : List[int] + Indices of the dataset items to include in the batch. + + Returns + ------- + patches : torch.Tensor + Image patches for the batch. + targets : torch.Tensor + Target labels corresponding to each patch. + """ + with ThreadPoolExecutor() as executor: + # Assign threads + pending = dict() + for i, idx in enumerate(batch_idxs): + thread = executor.submit(self.dataset.__getitem__, idx) + pending[thread] = i + + # Store results + patches = np.zeros((len(batch_idxs),) + self.patches_shape) + targets = np.zeros((len(batch_idxs), 1)) + for thread in as_completed(pending.keys()): + i = pending.pop(thread) + patches[i], _, targets[i] = thread.result() + return ml_util.to_tensor(patches), ml_util.to_tensor(targets) + + def _load_image_pc_batch(self, batch_idxs): + """ + Loads a batch of samples from the dataset using multithreading. + + Parameters + ---------- + batch_idxs : List[int] + Indices of the dataset items to include in the batch. + + Returns + ------- + batch : Dict[str, torch.Tensor] + Dictionary that maps modality names to batch features. + targets : torch.Tensor + Target labels corresponding to each patch. + """ + with ThreadPoolExecutor() as executor: + # Assign threads + pending = dict() + for i, idx in enumerate(batch_idxs): + thread = executor.submit(self.dataset.__getitem__, idx) + pending[thread] = i + + # Store results + patches = np.zeros((len(batch_idxs),) + self.patches_shape) + targets = np.zeros((len(batch_idxs), 1)) + point_clouds = np.zeros((len(batch_idxs), 3, 3600)) + for thread in as_completed(pending.keys()): + i = pending.pop(thread) + patches[i], subgraph, targets[i] = thread.result() + point_clouds[i] = subgraph_to_point_cloud(subgraph) + + # Set batch dictionary + batch = ml_util.TensorDict( + { + "img": ml_util.to_tensor(patches), + "point_cloud": ml_util.to_tensor(point_clouds), + } + ) + return batch, ml_util.to_tensor(targets) + + def _load_image_graph_batch(self, idxs): + """ + Loads a batch of samples from the dataset using multithreading. + + Parameters + ---------- + idxs : List[int] + Indices of the dataset items to include in the batch. + + Returns + ------- + batch : Dict[str, torch.Tensor] + Dictionary that maps modality names to batch features. + targets : torch.Tensor + Target labels corresponding to each patch. + """ + with ThreadPoolExecutor() as executor: + # Assign threads + threads = list() + for idx in idxs: + threads.append(executor.submit(self.dataset.__getitem__, idx)) + + # Store results + targets = np.zeros((len(idxs), 1)) + patches = np.zeros((len(idxs),) + self.patches_shape) + h, x, edge_index, batches = list(), list(), list(), list() + node_offset = 0 + for i, thread in enumerate(as_completed(threads)): + patches[i], subgraph, targets[i] = thread.result() + h_i, x_i, edge_index_i = subgraph_to_data(subgraph) + n_i = h_i.size(0) + + edge_index_i += node_offset + h.append(h_i) + x.append(x_i) + edge_index.append(edge_index_i) + batches.append( + torch.full((n_i,), i, dtype=torch.long) + ) + + node_offset += n_i + + # Combine subgraph batches + h = torch.cat(h, dim=0) + x = torch.cat(x, dim=0) + edge_index = torch.cat(edge_index, dim=1) + batches = torch.cat(batches, dim=0) + + # Set batch dictionary + batch = ml_util.TensorDict( + { + "img": ml_util.to_tensor(patches), + "graph": (h, x, edge_index, batches) + } + ) + return batch, ml_util.to_tensor(targets) diff --git a/src/neuron_proofreader/merge_proofreading/merge_datasets.py b/src/neuron_proofreader/merge_proofreading/merge_datasets.py index fe9f4420..63a838e0 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datasets.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datasets.py @@ -29,7 +29,7 @@ subgraph_to_point_cloud, ) from neuron_proofreader.merge_proofreading.merge_dataloading import ( - get_brain_merge_sites + get_brain_merge_sites, ) from neuron_proofreader.skeleton_graph import SkeletonGraph from neuron_proofreader.utils import ( @@ -71,6 +71,7 @@ class MergeSiteDataset(Dataset): patch_shape : Tuple[int], optional Shape of the 3D image patches to extract. """ + random_negative_example_prob = 0.8 def __init__( @@ -121,9 +122,7 @@ def __init__( self.merge_site_kdtrees = dict() # --- Load Data --- - def load_fragment_graphs( - self, brain_id, swc_pointer, use_anisotropy=True - ): + def load_fragment_graphs(self, brain_id, swc_pointer, use_anisotropy=True): """ Loads fragments containing merge mistakes for a whole-brain dataset, then stores them in the "graphs" attribute. @@ -139,7 +138,7 @@ def load_fragment_graphs( graph = SkeletonGraph( anisotropy=self.anisotropy, node_spacing=self.node_spacing, - use_anisotropy=use_anisotropy + use_anisotropy=use_anisotropy, ) graph.load(swc_pointer) @@ -771,6 +770,7 @@ def generate_negative_examples(self): negative_examples : List[dict] List of negative examples collected across all graphs. """ + # Subroutines def add_examples(): """ @@ -924,7 +924,7 @@ def __init__( is_multimodal=False, modality=None, sampler=None, - use_shuffle=True + use_shuffle=True, ): """ Instantiates a MergeSiteDataLoader object. @@ -970,11 +970,11 @@ def __iter__(self): for start in range(0, len(idxs), self.batch_size): end = min(start + self.batch_size, len(idxs)) if self.is_multimodal and self.modality == "graph": - yield self._load_image_graph_batch(idxs[start: end]) + yield self._load_image_graph_batch(idxs[start:end]) elif self.is_multimodal and self.modality == "pointcloud": - yield self._load_image_pc_batch(idxs[start: end]) + yield self._load_image_pc_batch(idxs[start:end]) else: - yield self._load_image_batch(idxs[start: end]) + yield self._load_image_batch(idxs[start:end]) def _load_image_batch(self, batch_idxs): """ @@ -1084,9 +1084,7 @@ def _load_image_graph_batch(self, idxs): h.append(h_i) x.append(x_i) edge_index.append(edge_index_i) - batches.append( - torch.full((n_i,), i, dtype=torch.long) - ) + batches.append(torch.full((n_i,), i, dtype=torch.long)) node_offset += n_i @@ -1100,7 +1098,7 @@ def _load_image_graph_batch(self, idxs): batch = ml_util.TensorDict( { "img": ml_util.to_tensor(patches), - "graph": (h, x, edge_index, batches) + "graph": (h, x, edge_index, batches), } ) return batch, ml_util.to_tensor(targets) diff --git a/src/neuron_proofreader/merge_proofreading/merge_inference.py b/src/neuron_proofreader/merge_proofreading/merge_inference.py index 73bb319d..474ea3a9 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_inference.py +++ b/src/neuron_proofreader/merge_proofreading/merge_inference.py @@ -250,7 +250,7 @@ def __init__( prefetch=64, segmentation_path=None, subgraph_radius=100, - use_new_mask=False + use_new_mask=False, ): # Call parent class super().__init__() @@ -317,7 +317,9 @@ def find_fragments_to_search(self): for nodes in nx.connected_components(self.graph): # Compute path length node = util.sample_once(list(nodes)) - length = self.graph.cable_length(max_depth=self.min_size, root=node) + length = self.graph.cable_length( + max_depth=self.min_size, root=node + ) # Check if path length satisfies threshold if length > self.min_size: @@ -431,7 +433,7 @@ def __init__( segmentation_path=None, step_size=10, subgraph_radius=100, - use_new_mask=False + use_new_mask=False, ): # Call parent class super().__init__( @@ -445,7 +447,7 @@ def __init__( prefetch=prefetch, segmentation_path=segmentation_path, subgraph_radius=subgraph_radius, - use_new_mask=use_new_mask + use_new_mask=use_new_mask, ) # Instance attributes @@ -537,7 +539,13 @@ def _get_batch(self, nodes, img, offset): patch_centers = self.get_patch_centers(nodes) - offset # Populate batch array - batch = np.empty((len(nodes), 2,) + self.patch_shape) + batch = np.empty( + ( + len(nodes), + 2, + ) + + self.patch_shape + ) for i, center in enumerate(patch_centers): s = img_util.get_slices(center, self.patch_shape) batch[i, 0, ...] = img_util.normalize(img[s]) @@ -550,7 +558,13 @@ def _get_multimodal_batch(self, nodes, img, offset): patch_centers = self.get_patch_centers(nodes) - offset # Populate batch array - patches = np.empty((len(nodes), 2,) + self.patch_shape) + patches = np.empty( + ( + len(nodes), + 2, + ) + + self.patch_shape + ) point_clouds = np.empty((len(nodes), 3, 3600), dtype=np.float32) for i, (node, center) in enumerate(zip(nodes, patch_centers)): s = img_util.get_slices(center, self.patch_shape) @@ -610,7 +624,7 @@ def __init__( prefetch=128, segmentation_path=None, subgraph_radius=100, - use_new_mask=False + use_new_mask=False, ): # Call parent class super().__init__( @@ -623,7 +637,7 @@ def __init__( prefetch=prefetch, segmentation_path=segmentation_path, subgraph_radius=subgraph_radius, - use_new_mask=use_new_mask + use_new_mask=use_new_mask, ) # Instance attributes diff --git a/src/neuron_proofreader/proposal_graph.py b/src/neuron_proofreader/proposal_graph.py index 92b624fd..37382210 100644 --- a/src/neuron_proofreader/proposal_graph.py +++ b/src/neuron_proofreader/proposal_graph.py @@ -85,7 +85,8 @@ def __init__( self.reset_proposals() self.proposal_generator = ProposalGenerator( - self, max_proposals_per_leaf=max_proposals_per_leaf, + self, + max_proposals_per_leaf=max_proposals_per_leaf, ) # Graph Loader @@ -138,7 +139,7 @@ def generate_proposals( self, search_radius, allow_nonleaf_proposals=False, - min_size_with_proposals=0 + min_size_with_proposals=0, ): """ Generates proposals from leaf nodes. diff --git a/src/neuron_proofreader/skeleton_graph.py b/src/neuron_proofreader/skeleton_graph.py index 32fc7b4d..eb75f46d 100644 --- a/src/neuron_proofreader/skeleton_graph.py +++ b/src/neuron_proofreader/skeleton_graph.py @@ -330,7 +330,7 @@ def remove_soma_merges(self): self.relabel_nodes() results = [ f"# Soma Fragments: {len(self.soma_centroids)}", - f"# Soma Merges: {n_soma_merges}" + f"# Soma Merges: {n_soma_merges}", ] return "\n".join(results) diff --git a/src/neuron_proofreader/split_proofreading/proposal_generation.py b/src/neuron_proofreader/split_proofreading/proposal_generation.py index 4f406673..729597cb 100644 --- a/src/neuron_proofreader/split_proofreading/proposal_generation.py +++ b/src/neuron_proofreader/split_proofreading/proposal_generation.py @@ -56,7 +56,7 @@ def __call__( self, initial_radius, allow_nonleaf_proposals=False, - min_size_with_proposals=0 + min_size_with_proposals=0, ): """ Generates edge proposals between fragments within the given search diff --git a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py index cef95a29..87e8fcd1 100644 --- a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py +++ b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py @@ -62,7 +62,7 @@ def __init__( brightness_clip=brightness_clip, patch_shape=patch_shape, padding=padding, - ), + ), ] def __call__(self, subgraph): @@ -249,7 +249,7 @@ def __init__( self.padding = padding # Image reader - self.img = img_util.TensorStoreReader(img_path) + self.img = img_util.TensorStoreImage(img_path) def __call__(self, subgraph, features): """ @@ -329,7 +329,6 @@ def create_segment_mask(self, proposal, shape, offset): img_util.annotate_voxels(mask, voxels, val=0.25) visited.add(frozenset({i, j})) return mask - def read_image(self, center, shape): """ diff --git a/src/neuron_proofreader/split_proofreading/split_inference.py b/src/neuron_proofreader/split_proofreading/split_inference.py index b3bfcc7a..ed84f02b 100644 --- a/src/neuron_proofreader/split_proofreading/split_inference.py +++ b/src/neuron_proofreader/split_proofreading/split_inference.py @@ -175,8 +175,12 @@ def __call__( while self.dataset.proposals: # Generate predictons cnt += 1 - self.log(f"\nThreshold={new_threshold} w/ only_leaf2leaf={only_leaf2leaf}") - preds = self.predict_proposals(suffix=f"{name}_round={cnt}_threshold={new_threshold}") + self.log( + f"\nThreshold={new_threshold} w/ only_leaf2leaf={only_leaf2leaf}" + ) + preds = self.predict_proposals( + suffix=f"{name}_round={cnt}_threshold={new_threshold}" + ) # Merge accetped proposals cur_threshold = new_threshold diff --git a/src/neuron_proofreader/utils/graph_util.py b/src/neuron_proofreader/utils/graph_util.py index f01c0d50..2fbdf4dd 100644 --- a/src/neuron_proofreader/utils/graph_util.py +++ b/src/neuron_proofreader/utils/graph_util.py @@ -7,19 +7,6 @@ Code that loads and preprocesses neuron fragments stored as SWC files, then constructs a custom graph object called a "FragmentsGraph". - Graph Loading Algorithm: - 1. Load Soma Locations (Optional) - - 2. Extract Irreducibles from SWC files - a. Build graph from SWC file - b. Break soma merges (optional) - c. Break high risk merges (optional) - d. Find irreducible nodes - e. Find irreducible edges - - -Note: We use the term "branch" to refer to a path in a graph from a branching - node to a leaf. """ from collections import deque @@ -80,7 +67,7 @@ def __init__( self.node_spacing = node_spacing self.prefetch = prefetch self.prune_depth = prune_depth - self.swc_reader = swc_util.Reader(anisotropy, min_cable_length, verbose) + self.swc_reader = swc_util.Reader(anisotropy, verbose) self.verbose = verbose def __call__(self, swc_pointer): @@ -146,7 +133,7 @@ def load(self, swc_dict): subgraph. """ # Build graph - graph = swc_util.to_graph(swc_dict, set_attrs=True) + graph = swc_util.to_graph(swc_dict) prune_branches(graph, self.prune_depth) # Extract irreducible components (if applicable) diff --git a/src/neuron_proofreader/utils/img_util.py b/src/neuron_proofreader/utils/img_util.py index ab5d50e9..04d74daf 100644 --- a/src/neuron_proofreader/utils/img_util.py +++ b/src/neuron_proofreader/utils/img_util.py @@ -8,7 +8,6 @@ """ -from abc import ABC, abstractmethod from fastremap import mask_except, renumber, unique from matplotlib.colors import ListedColormap from scipy.ndimage import zoom @@ -21,137 +20,29 @@ from neuron_proofreader.utils import util -class ImageReader(ABC): +class TensorStoreImage: """ - Abstract class to create image readers classes. + Class that reads images with the TensorStore library. """ def __init__(self, img_path): """ - Instantiates an ImageReader object. + Instantiates a TensorStoreImage object. Parameters ---------- img_path : str Path to image. - is_segmentation : bool, optional - Indication of whether image is a segmentation. """ - self.img_path = img_path - self._load_image() - - @abstractmethod - def _load_image(self): - """ - Method to be implemented by subclasses to load the image. - """ - pass - - def read(self, center, shape): - """ - Reads an image patch centered at the given voxel coordinate. - - Parameters - ---------- - center : Tuple[int] - Center of image patch to be read. - shape : Tuple[int] - Shape of image patch to be read. - - Returns - ------- - numpy.ndarray - Image patch. - """ - s = get_slices(center, shape) - return self.img[s] if self.img.ndim == 3 else self.img[(0, 0, *s)] - - def read_voxel(self, voxel, thread_id=None): - """ - Reads the intensity value at a given voxel. - - Parameters - ---------- - voxel : Tuple[int] - Voxel to be read. - thread_id : Any - Identifier associated with output. Default is None. - - Returns - ------- - int - Intensity value at voxel. - """ - return thread_id, self.img[voxel] - - def shape(self): - """ - Gets the shape of image. - - Returns - ------- - Tuple[int] - Shape of image. - """ - return self.img.shape - - -class TensorStoreReader(ImageReader): - """ - Class that reads images with TensorStore library. - """ - - def __init__(self, img_path): - """ - Instantiates a TensorStoreReader object. - - Parameters - ---------- - img_path : str - Path to image. - """ - self.driver = self.get_driver(img_path) - super().__init__(img_path) - - def get_driver(self, img_path): - """ - Gets the storage driver needed to read the image. - - Parameters - ---------- - img_path : str - Path to image - - Returns - ------- - str - Storage driver needed to read the image. - """ - if ".zarr" in img_path: - return "zarr" - elif ".n5" in img_path: - return "n5" - elif is_precomputed(img_path): - return "neuroglancer_precomputed" - else: - raise ValueError(f"Unsupported image format: {img_path}") - - def _load_image(self): - """ - Loads image with TensorStore library. - """ - # Extract metadata - bucket_name, path = util.parse_cloud_path(self.img_path) - storage_driver = get_storage_driver(self.img_path) - # Load image + bucket_name, inner_path = util.parse_cloud_path(img_path) self.img = ts.open( { - "driver": self.driver, + "driver": get_driver(img_path), "kvstore": { - "driver": storage_driver, + "driver": get_storage_driver(img_path), "bucket": bucket_name, - "path": path, + "path": inner_path, }, "context": { "cache_pool": {"total_bytes_limit": 1000000000}, @@ -162,19 +53,21 @@ def _load_image(self): } ).result() - # Check whether to absorb channel - if bucket_name == "allen-nd-goog" and is_precomputed(self.img_path): - self.img = self.img[ts.d["channel"][0]] - self.img = self.img[ts.d[0].transpose[2]] - self.img = self.img[ts.d[0].transpose[1]] + # Check for Google segmentation + if "from_google" in img_path: + self.img = self.img[ts.d[:].transpose[3, 2, 1, 0]] - def read(self, center, shape): + # Check dimensions + while self.img.ndim < 5: + self.img = self.img[ts.newaxis, ...] + + def read(self, voxel, shape): """ - Reads an image patch centered at the given voxel. + Reads a patch from an image given a voxel coordinate and patch shape. Parameters ---------- - center : Tuple[int] + voxel : Tuple[int] Center of image patch to be read. shape : Tuple[int] Shape of image patch to be read. @@ -184,29 +77,18 @@ def read(self, center, shape): numpy.ndarray Image patch. """ - try: - return super().read(center, shape).read().result() - except Exception: - print(f"Unable to read image patch at {center} w/ shape {shape}!") - return np.zeros(shape) + return self.img[(0, 0, *get_slices(voxel, shape))].read().result() - def read_voxel(self, voxel, thread_id): + def shape(self): """ - Reads the intensity value at a given voxel. - - Parameters - ---------- - voxel : Tuple[int] - Voxel to be read. - thread_id : Any - Identifier associated with output. + Gets the shape of image. Returns ------- - int - Intensity value at voxel. + Tuple[int] + Shape of image. """ - return thread_id, int(self.img[voxel].read().result()) + return self.img.shape # --- Visualization --- @@ -443,6 +325,27 @@ def get_contained_voxels(voxels, shape, buffer=0): return [v for v in voxels if is_contained(v, shape, buffer)] +def get_driver(img_path): + """ + Gets the driver needed to read the image. + + Parameters + ---------- + img_path : str + Path to image. + + Returns + ------- + str + Storage driver needed to read the image. + """ + if ".zarr" in img_path: + return "zarr" + elif is_precomputed(img_path): + return "neuroglancer_precomputed" + raise Exception(f"Invalid image path at {img_path}") + + def get_neighbors(voxel, shape): """ Gets the neighbors of a given voxel coordinate. diff --git a/src/neuron_proofreader/utils/swc_util.py b/src/neuron_proofreader/utils/swc_util.py index 419ecaab..6300fd53 100644 --- a/src/neuron_proofreader/utils/swc_util.py +++ b/src/neuron_proofreader/utils/swc_util.py @@ -18,6 +18,8 @@ these attributes in the same order. """ +from botocore import UNSIGNED +from botocore.config import Config from collections import deque from concurrent.futures import ( ProcessPoolExecutor, @@ -31,6 +33,7 @@ from zipfile import ZipFile import ast +import boto3 import networkx as nx import numpy as np import os @@ -41,10 +44,10 @@ class Reader: """ Class that reads SWC files stored in a (1) local directory, (2) local ZIP - archive, or (3) GCS directory, (4) GCS directory of ZIP archives. + archive, and (3) local directory of ZIP archives. """ - def __init__(self, anisotropy=(1.0, 1.0, 1.0), min_size=0, verbose=True): + def __init__(self, anisotropy=(1.0, 1.0, 1.0), verbose=True): """ Initializes a Reader object that reads SWC files. @@ -52,149 +55,140 @@ def __init__(self, anisotropy=(1.0, 1.0, 1.0), min_size=0, verbose=True): ---------- anisotropy : Tuple[float], optional Image to physical coordinates scaling factors to account for the - anisotropy of the microscope. Default is [1.0, 1.0, 1.0]. - min_size : int, optional - Threshold on the number nodes in SWC files that are parsed and - returned. Default is 0. - verbose : bool - Indication of whether to display a progress bar during loading. - Default is True. + anisotropy of the microscope. Default is (1.0, 1.0, 1.0). + verbose : bool, optional + Indication of whether to display a progress bar. Default is True. """ self.anisotropy = anisotropy - self.min_size = min_size self.verbose = verbose + # --- Read Data --- def __call__(self, swc_pointer): """ - Reads SWC files located at the path specified by "swc_pointer". + Loads SWC files based on the type pointer provided. Parameters ---------- - swc_pointer : str or List[str] + swc_pointer : str Object that points to SWC files to be read, must be one of: - file_path: Path to single SWC file - dir_path: Path to local directory with SWC files - zip_path: Path to local ZIP with SWC files - zip_dir_path: Path to local directory of ZIPs with SWC files + - s3_dir_path: Path to S3 prefix with SWC files - gcs_dir_path: Path to GCS prefix with SWC files - gcs_zip_dir_path: Path to GCS prefix with ZIPs of SWC files - - path_list: List of paths to local SWC files Returns ------- Deque[dict] - List of dictionaries whose keys and values are the attribute names - and values from the SWC files. Each dictionary contains the - following items: + Dictionaries whose keys and values are the attribute names and + values from the SWC files. Each dictionary contains the following: + items: - "id": unique identifier of each node in an SWC file. - "pid": parent ID of each node. - "radius": radius value corresponding to each node. - "xyz": coordinate corresponding to each node. - - "soma_nodes": nodes with soma type. - - "swc_name": name of SWC file, minus the ".swc". + - "filename": filename of SWC file + - "swc_id": name of SWC file, minus the ".swc". """ - # List of paths to SWC files + # List of paths if isinstance(swc_pointer, list): - return self.read_from_paths(swc_pointer) + return self.read_swcs(swc_pointer) # Directory containing... if os.path.isdir(swc_pointer): - # ZIP archives with SWC files + # Local ZIP archives with SWC files paths = util.list_paths(swc_pointer, extension=".zip") if len(paths) > 0: - return self.read_from_zips(swc_pointer) + return self.read_zips(swc_pointer, self.read_zip) - # SWC files - paths = util.list_paths(swc_pointer, extension=".swc") + # Local SWC files + paths = util.read_paths(swc_pointer, extension=".swc") if len(paths) > 0: - return self.read_from_paths(paths) + return self.read_swcs(paths) - raise Exception(f"Directory is invalid - {swc_pointer}") + raise Exception("Directory is Invalid!") # Path to... if isinstance(swc_pointer, str): - # Single SWC file in GCS - if util.is_gcs_path(swc_pointer) and swc_pointer.endswith(".swc"): - bucket_name, path = util.parse_cloud_path(swc_pointer) - return [self.read_from_gcs_swc(bucket_name, path)] - - # GCS directory - if util.is_gcs_path(swc_pointer): - return self.read_from_gcs(swc_pointer) + # Cloud GCS/S3 storage + if util.is_gcs_path(swc_pointer) or util.is_s3_path(swc_pointer): + return self.read_from_cloud(swc_pointer) - # ZIP archive with SWC files + # Local ZIP archive with SWC files if swc_pointer.endswith(".zip"): - return self.read_from_zip(swc_pointer) + return self.read_zip(swc_pointer) - # Single SWC file + # Local path to single SWC file if swc_pointer.endswith(".swc"): - return self.read_from_path(swc_pointer) + return self.read_swc(swc_pointer) - raise Exception(f"Path is invalid {swc_pointer}") + raise Exception("Path is Invalid!") - raise Exception(f"SWC Pointer is invalid {swc_pointer}") + raise Exception("SWC Pointer is Invalid!") - # --- Read subroutines --- - def read_from_paths(self, swc_paths): + def read_swc(self, path): """ - Reads a list of SWC files stored on the local machine. + Reads a single SWC file. Paramters --------- - swc_paths : List[str] - Paths to SWC files stored on the local machine. + path : str + Path to SWC file. Returns ------- - swc_dicts : Dequeue[dict] - List of dictionaries whose keys and values are the attribute - names and values from an SWC file. + dict + Dictionary whose keys and values are the attribute names and + values from an SWC file. """ - with ProcessPoolExecutor() as executor: - # Assign processes - processes = list() - for path in swc_paths: - processes.append(executor.submit(self.read_from_path, path)) + content = util.read_txt(path).splitlines() + filename = os.path.basename(path) + return self.parse(content, filename) - # Store results - swc_dicts = deque() - for process in as_completed(processes): - result = process.result() - if result: - swc_dicts.append(result) - return swc_dicts - - def read_from_path(self, path): + def read_swcs(self, swc_paths): """ - Reads a single SWC file stored on the local machine. + Reads SWC files stored in a GCS or S3 bucket. - Paramters - --------- - path : str - Path to SWC file stored on the local machine. + Parameters + ---------- + swc_paths : List[str] + List of paths to SWC files to be read. Returns ------- - swc_dict : dict + swc_dicts : Deque[dict] Dictionaries whose keys and values are the attribute names and values from an SWC file. """ - content = util.read_txt(path) - if len(content) > self.min_size - 10: - swc_dict = self.parse(content) - swc_dict["swc_name"] = get_swc_name(path) - return swc_dict - else: - return False + with ThreadPoolExecutor() as executor: + # Assign threads + threads = set() + for path in swc_paths: + threads.add(executor.submit(self.read_swc, path)) + + # Store results + swc_dicts = deque() + pbar = self.manual_progress_bar(len(threads)) + for thread in as_completed(threads): + result = thread.result() + if result: + swc_dicts.append(result) + if self.verbose: + pbar.update(1) + return swc_dicts - def read_from_zips(self, zip_dir): + def read_zips(self, zip_paths, read_fn): """ - Reads a directory containing ZIP archives with SWC files. + Reads SWC files stored in ZIP archives. Parameters ---------- - zip_dir : str - Path to directory containing ZIP archives with SWC files. + bucket_name : str + Name of bucket containing SWC files. + zip_paths : List[str] + Paths to ZIP archives containing SWC files to be read. Returns ------- @@ -202,273 +196,251 @@ def read_from_zips(self, zip_dir): Dictionaries whose keys and values are the attribute names and values from an SWC file. """ - # Initializations - zip_names = [f for f in os.listdir(zip_dir) if f.endswith(".zip")] - iterator = zip_names - if self.verbose: - pbar = tqdm(iterator, desc="Read SWCs") - - # Main + pbar = self.manual_progress_bar(len(zip_paths)) with ProcessPoolExecutor() as executor: - # Assign threads - processes = list() - for f in iterator: - zip_path = os.path.join(zip_dir, f) - processes.append(executor.submit(self.read_from_zip, zip_path)) + # Assign processes + futures = {executor.submit(read_fn, path) for path in zip_paths} # Store results swc_dicts = deque() - for process in as_completed(processes): - swc_dicts.extend(process.result()) + for process in as_completed(futures): + try: + swc_dicts.extend(process.result()) + except RefreshError: + pass + if self.verbose: pbar.update(1) return swc_dicts - def read_from_zip(self, zip_path): + def read_zip(self, zip_path): """ - Reads SWC files from a ZIP archive stored on the local machine. + Reads SWC files from a ZIP archive. Paramters --------- - str : str - Path to a ZIP archive on the local machine. + zip_path : str + Path to ZIP archive. Returns ------- - swc_dicts : Dequeue[dict] - List of dictionaries whose keys and values are the attribute - names and values from an SWC file. + swc_dicts : Deque[dict] + Dictionaries whose keys and values are the attribute names and + values from an SWC file. """ with ThreadPoolExecutor() as executor: - with ZipFile(zip_path, "r") as zf: - # Submit threads - threads = list() - for f in [f for f in zf.namelist() if f.endswith(".swc")]: - threads.append( - executor.submit(self.read_from_zipped_file, zf, f) - ) + # Assign threads + threads = set() + zf = ZipFile(zip_path, "r") + for name in [f for f in zf.namelist() if f.endswith(".swc")]: + threads.add(executor.submit(self.read_zipped_swc, zf, name)) - # Store results - swc_dicts = deque() - for thread in as_completed(threads): - swc_dict = thread.result() - if swc_dict: - swc_dicts.append(swc_dict) + # Store results + swc_dicts = deque() + for thread in as_completed(threads): + result = thread.result() + if result: + swc_dicts.append(result) return swc_dicts - def read_from_zipped_file(self, zip_file, path): + def read_zipped_swc(self, zipfile, path): """ - Reads SWC file stored in a ZIP archive. + Reads an SWC file stored in a ZIP archive. Parameters ---------- - zip_file : ZipFile - ZIP archive containing SWC file to be read. + zipfile : ZipFile + ZIP archive containing SWC files. path : str - Path to SWC file to be read. + Path to SWC file. Returns ------- - swc_dict : dict - Dictionaries whose keys and values are the attribute names and + dict + Dictionary whose keys and values are the attribute names and values from an SWC file. """ - content = util.read_zip(zip_file, path).splitlines() - if len(content) > self.min_size - 10: - swc_dict = self.parse(content) - swc_dict["swc_name"] = get_swc_name(path) - return swc_dict - else: - return False + content = util.read_zip(zipfile, path).splitlines() + filename = os.path.basename(path) + return self.parse(content, filename) - def read_from_gcs(self, gcs_path): + def read_from_cloud(self, path): """ - Reads SWC files stored in a GCS bucket. + Reads SWC files stored in a GCS or S3 bucket. Parameters ---------- - gcs_path : str - Path to SWC files located in a GCS bucket. + path : str + Path to location in a GCS or S3 bucket containing SWC files, + must be in the format "{scheme}://{bucket_name}/{prefix}". Returns ------- - Dequeue[dict] - List of dictionaries whose keys and values are the attribute - names and values from an SWC file. + Deque[dict] + Dictionaries whose keys and values are the attribute names and + values from an SWC file. """ - # List filenames - bucket_name, prefix = util.parse_cloud_path(gcs_path) - swc_paths = util.list_gcs_filenames(bucket_name, prefix, ".swc") - zip_paths = util.list_gcs_filenames(bucket_name, prefix, ".zip") + # Extact info + assert util.is_s3_path(path) or util.is_gcs_path(path) + use_s3 = util.is_s3_path(path) + + # List paths + swc_paths = util.list_cloud_paths(path, ".swc") + zip_paths = util.list_cloud_paths(path, ".zip") # Call reader - if len(swc_paths) > 0: - return self.read_from_gcs_swcs(bucket_name, swc_paths) - if len(zip_paths) > 0: - return self.read_from_gcs_zips(bucket_name, zip_paths) + if swc_paths: + return self.read_swcs(swc_paths) + elif zip_paths: + read_fn = self.read_s3_zip if use_s3 else self.read_gcs_zip + return self.read_zips(zip_paths, read_fn) - # Error - raise Exception(f"GCS Pointer is invalid {gcs_path}") + raise Exception(f"SWC Pointer is invalid {path}") - def read_from_gcs_swcs(self, bucket_name, swc_paths): + def read_gcs_swc(self, path): """ - Reads SWC files stored in a GCS bucket. + Reads a single SWC file stored in a GCS bucket. Parameters ---------- - bucket_name : str - Name of GCS bucket containing SWC files. - swc_paths : List[str] - Paths to SWC files. + path : List[str] + Path to SWC file to be read. Returns ------- - swc_dicts : Dequeue[dict] - List of dictionaries whose keys and values are the attribute - names and values from an SWC file. + Deque[dict] + Dictionaries whose keys and values are the attribute names and + values from an SWC file. """ - if self.verbose: - pbar = tqdm(total=len(swc_paths), desc="Read SWCs") - - with ThreadPoolExecutor() as executor: - # Assign threads - threads = list() - for path in swc_paths: - threads.append( - executor.submit(self.read_from_gcs_swc, bucket_name, path) - ) + # Initialize cloud reader + bucket_name, subpath = util.parse_cloud_path(path) + bucket = storage.Client().bucket(bucket_name) + blob = bucket.blob(subpath) - # Store results - swc_dicts = deque() - for thread in as_completed(threads): - result = thread.result() - if result: - swc_dicts.append(result) - if self.verbose: - pbar.update(1) - return swc_dicts + # Parse swc contents + content = blob.download_as_text().splitlines() + filename = os.path.basename(subpath) + return self.parse(content, filename) - def read_from_gcs_swc(self, bucket_name, path): + def read_gcs_zip(self, path): """ - Reads a single SWC file stored in a GCS bucket. + Reads SWC files stored in a ZIP archive downloaded from a GCS + bucket. Parameters ---------- - bucket_name : str - Name of GCS bucket containing SWC files. - swc_path : str - Path to SWC file to be read. + path : str + Path to ZIP archive containing SWC files to be read. Returns ------- - swc_dict : dict + swc_dicts : Deque[dict] Dictionaries whose keys and values are the attribute names and values from an SWC file. """ - # Initialize cloud reader - client = storage.Client() - bucket = client.bucket(bucket_name) - blob = bucket.blob(path) + # Download ZIP + bucket_name, path = util.parse_cloud_path(path) + bucket = storage.Client().bucket(bucket_name) + try: + zip_content = bucket.blob(path).download_as_bytes() + except TransportError: + print(f"Failed to read {path}!") + return deque() - # Parse swc contents - content = blob.download_as_text().splitlines() - if len(content) > self.min_size - 10: - swc_dict = self.parse(content) - swc_dict["swc_name"] = get_swc_name(path) - return swc_dict - else: - return False + # Parse ZIP + swc_dicts = deque() + zip_content = bucket.blob(path).download_as_bytes() + with ZipFile(BytesIO(zip_content), "r") as zf: + with ThreadPoolExecutor() as executor: + # Assign threads + threads = set() + for name in zf.namelist(): + threads.add( + executor.submit(self.read_zipped_swc, zf, name) + ) + + # Process results + for thread in as_completed(threads): + result = thread.result() + if result: + swc_dicts.append(result) + return swc_dicts - def read_from_gcs_zips(self, bucket_name, zip_paths): + def read_s3_zip(self, path): """ - Reads SWC files from ZIP archives stored in a GCS bucket. + Reads SWC files stored in a ZIP archive downloaded from an S3 + bucket. Parameters ---------- - bucket_name : str - Name of GCS bucket containing SWC files. - zip_paths : List[str] - Paths to ZIP archives in a GCS bucket. + path : str + Path to ZIP archive containing SWC files to be read. Returns ------- - swc_dicts : Dequeue[dict] - List of dictionaries whose keys and values are the attribute - names and values from an SWC file. + swc_dicts : Deque[dict] + Dictionaries whose keys and values are the attribute names and + values from an SWC file. """ - # Initializations - batch_size = 1000 - if self.verbose: - pbar = tqdm(total=len(zip_paths), desc="Read SWCs") - - # Main - swc_dicts = deque() - with ProcessPoolExecutor() as executor: - for i in range(0, len(zip_paths), batch_size): - # Assign processes - processes = list() - for zip_path in zip_paths[i: i + batch_size]: - processes.append( - executor.submit( - self.read_from_gcs_zip, bucket_name, zip_path - ) + # Initialize cloud reader + bucket, key = util.parse_cloud_path(path) + s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED)) + zip_content = s3.get_object(Bucket=bucket, Key=key)["Body"].read() + + # Parse ZIP + with ZipFile(BytesIO(zip_content), "r") as zf: + with ThreadPoolExecutor() as executor: + # Assign threads + threads = set() + for name in zf.namelist(): + threads.add( + executor.submit(self.read_zipped_swc, zf, name) ) # Store results - for process in as_completed(processes): - try: - swc_dicts.extend(process.result()) - except RefreshError: - pass - if self.verbose: - pbar.update(1) + swc_dicts = deque() + for thread in as_completed(threads): + result = thread.result() + if result: + swc_dicts.append(result) return swc_dicts - def read_from_gcs_zip(self, bucket_name, zip_path, filenames=None): + # -- Process Text --- + def iterator(self, iterator): """ - Reads SWC files stored in a ZIP archive downloaded from a cloud - bucket. + Gets an iterator that optionally displays a progress bar. Parameters ---------- - bucket_name : str - Name of GCS bucket containing SWC files. - zip_path : str - Path to ZIP archive to be read. - filenames : None or List[str], optional - Filenames to be read if provided. Default is None. + iterator : iterable + Object to be iterated over. Returns ------- - swc_dicts : Dequeue[dict] - List of dictionaries whose keys and values are the attribute - names and values from an SWC file. + tqdm.tqdm + Iterator that is optionally wrapped in a progress bar. """ - try: - # Download zip - client = storage.Client() - bucket = client.bucket(bucket_name) - zip_content = bucket.blob(zip_path).download_as_bytes() - except TransportError: - print(f"Failed to read {zip_path}!") - return deque() + return tqdm(iterator, desc="Read SWCs") if self.verbose else iterator - # Process files - swc_dicts = deque() - with ZipFile(BytesIO(zip_content), "r") as zip_file: - filenames = zip_file.namelist() if filenames is None else filenames - for filename in filenames: - result = self.read_from_zipped_file(zip_file, filename) - if result: - swc_dicts.append(result) - return swc_dicts + def manual_progress_bar(self, total): + """ + Gets progress bar that needs to be updated manually. + + Parameters + ---------- + total : int + Size of progress bar. + + Returns + ------- + tqdm.tqdm + Iterator that is optionally wrapped in a progress bar. + """ + return tqdm(total=total, desc="Read SWCs") if self.verbose else None - # --- Process content --- - def parse(self, content): + def parse(self, content, filename): """ Parses an SWC file to extract the content which is stored in a dict. - Note that node_ids from SWC are reindex from 0 to n-1 where n is the - number of nodes in the SWC file. Parameters ---------- @@ -477,34 +449,40 @@ def parse(self, content): Returns ------- - swc_dict : dict - Dictionaries whose keys and values are the attribute names - and values from an SWC file. + dict + Dictionary whose keys and values are the attribute names and + values from an SWC file. """ # Initializations + swc_name, _ = os.path.splitext(filename) content, offset = self.process_content(content) - swc_dict = { - "id": np.zeros((len(content)), dtype=int), - "radius": np.zeros((len(content)), dtype=np.float16), - "pid": np.zeros((len(content)), dtype=int), - "xyz": np.zeros((len(content), 3), dtype=np.float32), - "soma_nodes": set(), - } - - # Parse content - for i, line in enumerate(content): - parts = line.split() - swc_dict["id"][i] = parts[0] - swc_dict["radius"][i] = float(parts[-2]) - swc_dict["pid"][i] = parts[-1] - swc_dict["xyz"][i] = self.read_xyz(parts[2:5], offset) - if int(parts[1]) == 1: - swc_dict["soma_nodes"].add(parts[0]) - - # Convert radius from nanometers to microns - if swc_dict["radius"][0] > 100: - swc_dict["radius"] /= 1000 - return swc_dict + if len(content) > 0: + swc_dict = { + "id": np.zeros((len(content)), dtype=int), + "pid": np.zeros((len(content)), dtype=int), + "radius": np.zeros((len(content)), dtype=float), + "xyz": np.zeros((len(content), 3), dtype=np.int32), + "soma_nodes": set(), + "swc_name": swc_name, + } + + # Parse content + for i, line in enumerate(content): + parts = line.split() + swc_dict["id"][i] = parts[0] + swc_dict["pid"][i] = parts[-1] + swc_dict["radius"][i] = float(parts[-2]) + swc_dict["xyz"][i] = self.read_coordinate(parts[2:5], offset) + + if int(parts[1]) == 1: + swc_dict["soma_nodes"].add(parts[0]) + + # Convert radius from nanometers to microns + if swc_dict["radius"][0] > 100: + swc_dict["radius"] /= 1000 + return swc_dict + else: + return None def process_content(self, content): """ @@ -520,33 +498,33 @@ def process_content(self, content): Returns ------- content : List[str] - List of strings representing the lines of text starting from the - line immediately after the last commented line. - offset : List[float] - Offset used to shift coordinates. + Lines from an SWC file after comments. + offset : Tuple[int] + Offset used to shift coordinate. """ offset = (0, 0, 0) for i, line in enumerate(content): if line.startswith("# OFFSET"): - offset = self.read_xyz(line.split()[2:5]) - if not line.startswith("#") and len(line) > 0: + parts = line.split() + offset = self.read_coordinate(parts[2:5]) + if not line.startswith("#") and len(line.strip()) > 0: return content[i:], offset - def read_xyz(self, xyz_str, offset=(0, 0, 0)): + def read_coordinate(self, xyz_str, offset=(0, 0, 0)): """ - Reads a 3D coordinate from a string and transforms it. + Reads a coordinate from a string and converts it to voxel coordinates. Parameters ---------- xyz_str : str - Coordinate stored as a str. - offset : List[float], optional - Shift applied to coordinate. Default is (0, 0, 0). + Coordinate stored as a string. + offset : Tuple[int] + Offset of coordinates in SWC file. Default is (0, 0, 0). Returns ------- - List[float] - Coordinate of node from an SWC file. + Tuple[int] + xyz coordinates of an entry from an SWC file. """ iterator = zip(self.anisotropy, xyz_str, offset) return [a * (float(s) + o) for a, s, o in iterator] @@ -575,21 +553,21 @@ def write_points( radius : float, optional Radius to be used in SWC file. Default is 10. """ - zip_writer = ZipFile(zip_path, write_mode) + zf = ZipFile(zip_path, write_mode) for i, xyz in enumerate(points): filename = prefix + str(i + 1) + ".swc" - to_zipped_point(zip_writer, filename, xyz, color=color, radius=radius) + to_zipped_point(zf, filename, xyz, color=color, radius=radius) -def to_zipped_point(zip_writer, filename, xyz, color=None, radius=5): +def to_zipped_point(zf, filename, xyz, color=None, radius=5): """ Writes a point to an SWC file format, which is then stored in a ZIP archive. Parameters ---------- - zip_writer : zipfile.ZipFile - ZipFile object that will store the generated SWC file. + zf : zipfile.ZipFile + ZipFile used to write the generated SWC file. filename : str Filename of SWC file. xyz : ArrayLike @@ -597,7 +575,7 @@ def to_zipped_point(zip_writer, filename, xyz, color=None, radius=5): color : str, optional Color of nodes. Default is None. radius : float, optional - Radius of point. Default is 5um. + Radius (in microns) of point. Default is 5. """ with StringIO() as text_buffer: # Preamble @@ -610,7 +588,7 @@ def to_zipped_point(zip_writer, filename, xyz, color=None, radius=5): text_buffer.write("\n" + f"1 5 {x} {y} {z} {radius} -1") # Finish - zip_writer.writestr(filename, text_buffer.getvalue()) + zf.writestr(filename, text_buffer.getvalue()) # --- Helpers --- @@ -621,7 +599,7 @@ def get_segment_id(swc_name): Parameters ---------- swc_name : str - SWC filename, expected to be in the format "{segment_id}.swc". + SWC filename in the format "{segment_id}.swc". Returns ------- @@ -637,7 +615,7 @@ def get_segment_id(swc_name): def get_swc_name(path): """ - Gets name of the SWC file loacted at the given path, minus the extension. + Gets name of the SWC file at the given path, minus the extension. Parameters ---------- @@ -654,9 +632,9 @@ def get_swc_name(path): return name -def to_graph(swc_dict, set_attrs=False): +def to_graph(swc_dict): """ - Converts an SWC dict to a NetworkX graph with reindexed nodes. + Converts an SWC dictionary to a NetworkX graph with reindexed nodes. Parameters ---------- @@ -679,9 +657,10 @@ def to_graph(swc_dict, set_attrs=False): ] # Build graph with reindexed edges - graph = nx.Graph(swc_name=swc_dict["swc_name"]) + graph = nx.Graph( + swc_name=swc_dict["swc_name"], + radius=swc_dict["radius"], + xyz=swc_dict["xyz"], + ) graph.add_edges_from(edges) - if set_attrs: - graph.graph["xyz"] = swc_dict["xyz"] - graph.graph["radius"] = swc_dict["radius"] return graph diff --git a/src/neuron_proofreader/utils/util.py b/src/neuron_proofreader/utils/util.py index d0abe0ce..5cf509b8 100644 --- a/src/neuron_proofreader/utils/util.py +++ b/src/neuron_proofreader/utils/util.py @@ -223,20 +223,25 @@ def read_json(path): def read_txt(path): """ - Reads txt file located at the given path. + Reads txt file at the given path. Parameters ---------- path : str - Path to txt file to be read. + Path to txt file. Returns ------- str - Contents of txt file. + Text from the txt file. """ - with open(path, "r") as f: - return f.read().splitlines() + if is_s3_path(path): + return read_s3_txt(path) + elif is_gcs_path(path): + return read_gcs_txt(path) + else: + with open(path, "r") as f: + return f.read() def read_zip(zip_file, path): @@ -328,7 +333,59 @@ def write_txt(path, contents): f.close() -# --- GCS utils --- +# --- Cloud Utils --- +def list_cloud_paths(path, extension=""): + """ + Lists all files in a GCS/S3 bucket with the given extension. + + Parameters + ---------- + path : str + Path to cloud prefix to be searched, must be in the format: + f"{scheme}://{bucket_name}/{prefix}". + extension : str, optional + File extension of filenames to be listed. Default is an empty string. + + Returns + ------- + List[str] + Filenames stored at the GCS path with the given extension. + """ + assert is_gcs_path(path) or is_s3_path(path) + bucket_name, prefix = parse_cloud_path(path) + list_fn = list_gcs_paths if is_gcs_path(path) else list_s3_paths + return list_fn(bucket_name, prefix, extension=extension) + + +def parse_cloud_path(path): + """ + Parses a cloud storage path into its bucket name and key/prefix. Supports + paths of the form: "{scheme}://bucket_name/prefix" or without a scheme. + + Parameters + ---------- + path : str + Path to be parsed. + + Returns + ------- + bucket_name : str + Name of the bucket. + prefix : str + Cloud prefix. + """ + # Remove s3:// or gs:// if present + if path.startswith("s3://") or path.startswith("gs://"): + path = path[len("s3://"):] + + # Split path + parts = path.split("/", 1) + bucket_name = parts[0] + prefix = parts[1] if len(parts) > 1 else "" + return bucket_name, prefix + + +# --- GCS Utils --- def check_gcs_file_exists(bucket_name, path): """ Checks if the given path exists. @@ -368,14 +425,14 @@ def is_gcs_path(path): return path.startswith("gs://") -def list_gcs_filenames(bucket_name, prefix, extension=""): +def list_gcs_paths(bucket_name, prefix, extension=""): """ - Lists all files in a GCS bucket with the given extension. + Lists paths at a GCS prefix with the given extension. Parameters ---------- bucket_name : str - Name of bucket to be searched. + Name of bucket containing prefix. prefix : str Path to location within bucket to be searched. extension : str, optional @@ -384,11 +441,14 @@ def list_gcs_filenames(bucket_name, prefix, extension=""): Returns ------- List[str] - Filenames stored at the GCS path with the given extension. + Paths under the GCS prefix with the given extension. """ bucket = storage.Client().bucket(bucket_name) - blobs = bucket.list_blobs(prefix=prefix) - return [blob.name for blob in blobs if extension in blob.name] + paths = list() + for name in [b.name for b in bucket.list_blobs(prefix=prefix)]: + if extension in name: + paths.append(os.path.join(f"gs://{bucket_name}", name)) + return paths def list_gcs_subdirectories(bucket_name, prefix): @@ -425,6 +485,25 @@ def list_gcs_subdirectories(bucket_name, prefix): return subdirs +def read_gcs_txt(path): + """ + Reads a txt file stored in a GCS bucket. + + Parameters + ---------- + path : str + Path to txt file to be read. + + Returns + ------- + str + Contents of txt file. + """ + bucket_name, subpath = parse_cloud_path(path) + bucket = storage.Client().bucket(bucket_name) + return bucket.blob(subpath).download_as_text() + + def read_json_from_gcs(bucket_name, blob_path): """ Reads JSON file stored in a GCS bucket. @@ -447,7 +526,7 @@ def read_json_from_gcs(bucket_name, blob_path): return json.loads(blob.download_as_text()) -# --- S3 utils --- +# --- S3 Utils --- def is_s3_path(path): """ Checks if the path is an S3 path. @@ -496,6 +575,60 @@ def list_s3_prefixes(bucket_name, prefix): return list() +def list_s3_paths(bucket_name, prefix, extension=""): + """ + Lists all object keys in a public S3 bucket under a given prefix, + optionally filters by file extension. + + Parameters + ---------- + bucket_name : str + Name of the S3 bucket. + prefix : str + Prefix to search under. + extension : str, optional + File extension to filter by. Default is an empty string. + + Returns + ------- + paths : List[str] + S3 object keys that match the prefix and extension filter. + """ + # Create an anonymous client for public buckets + s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED)) + response = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix) + + # List all objects under the prefix + paths = list() + if "Contents" in response: + for obj in response["Contents"]: + filename = obj["Key"] + if filename.endswith(extension): + path = os.path.join(f"s3://{bucket_name}", filename) + paths.append(path) + return paths + + +def read_s3_txt(path): + """ + Reads a txt file stored in an S3 bucket. + + Parameters + ---------- + path : str + Path to txt file to be read. + + Returns + ------- + str + Contents of txt file. + """ + bucket_name, subpath = parse_cloud_path(path) + s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED)) + obj = s3.get_object(Bucket=bucket_name, Key=subpath) + return obj["Body"].read().decode("utf-8") + + def upload_dir_to_s3(dir_path, bucket_name, prefix): """ Uploads a directory on the local machine to an S3 bucket. @@ -618,34 +751,6 @@ def numpy_to_hashable(arr): return [tuple(item) for item in arr.tolist()] -def parse_cloud_path(path): - """ - Parses a cloud storage path into its bucket name and key/prefix. Supports - paths of the form: "{scheme}://bucket_name/prefix" or without a scheme. - - Parameters - ---------- - path : str - Path to be parsed. - - Returns - ------- - bucket_name : str - Name of the bucket. - prefix : str - Cloud prefix. - """ - # Remove s3:// or gs:// if present - if path.startswith("s3://") or path.startswith("gs://"): - path = path[len("s3://"):] - - # Split path - parts = path.split("/", 1) - bucket_name = parts[0] - prefix = parts[1] if len(parts) > 1 else "" - return bucket_name, prefix - - def sample_once(my_container): """ Samples a single element from the given container. diff --git a/tests/__init__.py b/tests/__init__.py index d0a85479..024d51f2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,2 +1,3 @@ """Init package""" + __version__ = "0.0.0" From 6bb360cf3b28891cccb789b6d83f36ee8692df4b Mon Sep 17 00:00:00 2001 From: anna-grim Date: Tue, 2 Jun 2026 17:45:17 +0000 Subject: [PATCH 02/11] feat: generate neg exs, site sampling --- .../machine_learning/augmentation.py | 4 +- .../merge_datamodules_v2.py | 293 +++++++++++------- src/neuron_proofreader/utils/swc_util.py | 4 +- src/neuron_proofreader/utils/util.py | 11 +- 4 files changed, 183 insertions(+), 129 deletions(-) diff --git a/src/neuron_proofreader/machine_learning/augmentation.py b/src/neuron_proofreader/machine_learning/augmentation.py index 25429dce..4d49e007 100644 --- a/src/neuron_proofreader/machine_learning/augmentation.py +++ b/src/neuron_proofreader/machine_learning/augmentation.py @@ -209,7 +209,7 @@ class RandomContrast3D: Adjusts the contrast of a 3D image by scaling voxel intensities. """ - def __init__(self, p_low=(0, 80), p_high=(98, 100)): + def __init__(self, p_low=(0, 80), p_high=(99, 100)): """ Initializes a RandomContrast3D transformer. @@ -242,7 +242,7 @@ class RandomNoise3D: Adds random Gaussian noise to a 3D image. """ - def __init__(self, max_std=0.2): + def __init__(self, max_std=0.1): """ Initializes a RandomNoise3D transformer. diff --git a/src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py b/src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py index 585892f4..ea6d25a3 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py @@ -32,6 +32,7 @@ import networkx as nx import numpy as np import os +import pandas as pd import queue import random import threading @@ -54,40 +55,33 @@ ) -# --------------------------------------------------------------------------- -# BrainDataset -# --------------------------------------------------------------------------- - - +# -- Datasets --- class BrainDataset: """ All data and retrieval logic for a single whole-brain dataset. Parameters ---------- - brain_id : str - Unique identifier for this brain. anisotropy : Tuple[float] Voxel-to-physical scaling factors. + brain_id : str + Unique identifier for this brain. brightness_clip : int Maximum raw image intensity before normalisation. - subgraph_radius : int - Radius (um) used when extracting rooted subgraphs. - node_spacing : int - Spacing (um) between neighbouring graph nodes. + node_spacing : float + Spacing (in microns) between neighbouring graph nodes. patch_shape : Tuple[int] Shape of the 3D image patches to extract. - use_segmentation_mask : bool - Whether to overlay a volumetric segmentation when building the - segment mask. + subgraph_radius : float + Radius (in microns) used when extracting rooted subgraphs. """ + giant_component_cable_length = 10**4 + def __init__( self, brain_id, sites_prefix, - pos_site_paths, - neg_site_paths, img_path, anisotropy=(1.0, 1.0, 1.0), brightness_clip=500, @@ -96,28 +90,32 @@ def __init__( patch_shape=(128, 128, 128), probability_random_nonmerge_site=0.5, use_transform=False, - use_segmentation_mask=False, ): # Instance attributes self.anisotropy = anisotropy self.brain_id = brain_id self.brightness_clip = brightness_clip - self.subgraph_radius = subgraph_radius + self.ignore_fragments = set() self.node_spacing = node_spacing self.patch_shape = patch_shape self.probability_random_nonmerge_site = ( probability_random_nonmerge_site ) - self.use_segmentation_mask = use_segmentation_mask + self.subgraph_radius = subgraph_radius # Core data structures self.graph = self.load_fragments(sites_prefix) self.img = img_util.TensorStoreImage(img_path) - self.segmentation_reader = None + self.set_giant_components() - self.nonmerge_sites = self.load_sites(neg_site_paths) - self.merge_sites = self.load_sites(pos_site_paths) + self.merge_sites = self.load_sites( + os.path.join(sites_prefix, "merge_sites/") + ) + self.nonmerge_sites = self.load_sites( + os.path.join(sites_prefix, "nonmerge_sites/") + ) self.set_merge_site_info() + self.set_valid_branching_nodes() # Image augmentation for training self.transform = ImageTransforms() if use_transform else None @@ -132,24 +130,64 @@ def load_fragments(self, sites_prefix): graph.load(os.path.join(sites_prefix, "fragments")) return graph - def load_sites(self, site_paths): + def load_sites(self, sites_prefix): sites = list() swc_reader = swc_util.Reader(verbose=False) - for swc_dict in swc_reader(site_paths): - sites.append(swc_dict["xyz"]) - return np.vstack(sites) + for swc_dict in swc_reader(sites_prefix): + xyz = swc_dict["xyz"][0] + dd, ii = self.graph.kdtree.query(xyz) + sites.append( + {"xyz": xyz, "node": ii, "filename": swc_dict["swc_name"]} + ) + return pd.DataFrame(sites) def set_merge_site_info(self): - # Build kdtree of merge sites - self.merge_sites_kdtree = KDTree(self.merge_sites) + if len(self.merge_sites) > 0: + # Build kdtree of merge sites + xyz_arr = np.vstack(self.merge_sites["xyz"]) + self.merge_sites_kdtree = KDTree(xyz_arr) + + # Store fragment IDs corresponding to merge sites + for xyz in self.merge_sites["xyz"]: + _, ii = self.graph.kdtree.query(xyz) + self.ignore_fragments.add(self.graph.node_component_id[ii]) + + def set_giant_components(self): + for nodes in map(list, nx.connected_components(self.graph)): + # Compute cable length + root = util.sample_once(nodes) + cable_length = self.graph.cable_length( + max_depth=self.giant_component_cable_length, root=root + ) + + # Check if giant component + if cable_length > self.giant_component_cable_length: + self.ignore_fragments.add(self.graph.node_component_id[root]) + + def set_valid_branching_nodes(self): + self.valid_branching_nodes = set() + for node in self.graph.branching_nodes(): + # Reject if high-degree + if self.graph.degree(node) > 3: + continue - # Store fragment IDs corresponding to merge sites - self.fragments_with_merge = set() - for xyz in self.merge_sites: - _, ii = self.graph.kdtree.query(xyz) - self.fragments_with_merge.add(self.graph.node_component_id[ii]) + # Reject if node belongs to fragment with merge + # if self.graph.node_component_id[node] in self.ignore_fragments: + # continue - # --- Site retrieval --- + # Reject if branching and near another branching node + if self._has_nearby_branching(node): + continue + + # Reject if near merge site + dd, _ = self.merge_sites_kdtree.query(self.graph.node_xyz[node]) + if dd < 50: + continue + + # Node is valid + self.valid_branching_nodes.add(node) + + # --- Site Retrieval --- def __getitem__(self, idx): # Get example node, label = self.get_site(idx) @@ -174,49 +212,20 @@ def __getitem__(self, idx): def get_site(self, idx): if idx > 0: - return self.get_merge_site(idx) + return self.merge_sites["node"][idx], 1 elif np.random.random() < self.probability_random_nonmerge_site: return self.get_random_nonmerge_site() elif abs(idx) < len(self.nonmerge_sites): - return self.get_indexed_nonmerge_site(abs(idx)) + return self.nonmerge_sites["node"][abs(idx)], 0 else: return self.get_random_nonmerge_site() - def get_merge_site(self, idx): - _, node = self.graph.kdtree.query(self.merge_sites[idx]) - return node, 1 - - def get_indexed_nonmerge_site(self, idx): - _, node = self.graph.kdtree.query(self.nonmerge_sites[idx]) - return node, 0 - def get_random_nonmerge_site(self): - # Search for valid nonmerge site - branching_nodes = self.graph.branching_nodes() - use_branching = branching_nodes and random.random() < 0.5 - for cnt in range(10**4): - # Sample node - if use_branching: - node = util.sample_once(branching_nodes) - else: - node = util.sample_once(self.graph.nodes) - - # Reject if high-degree - if self.graph.degree(node) > 3: - continue - - # Reject if branching and near another branching node - if use_branching and self._has_nearby_branching(node): - continue - - # Reject if near merge site - dd, _ = self.merge_sites_kdtree.query(self.graph.node_xyz[node]) - if dd < 100: - continue - - # Site is valid - break - + use_branching = self.valid_branching_nodes and random.random() < 0.5 + if use_branching: + node = util.sample_once(self.valid_branching_nodes) + else: + node = util.sample_once(self.graph.nodes) return node, 0 # --- Image / Mask Extraction --- @@ -239,8 +248,7 @@ def get_img_patch(self, center): def get_segment_mask(self, center, subgraph): """ - Builds the segment mask for subgraph, optionally incorporating a - volumetric segmentation read. + Builds the segment mask for subgraph. Parameters ---------- @@ -252,13 +260,7 @@ def get_segment_mask(self, center, subgraph): ------- numpy.ndarray """ - if self.use_segmentation_mask: - return self._segment_mask_with_segmentation(center, subgraph) - return self._segment_mask_skeleton_only(subgraph) - - def _segment_mask_skeleton_only(self, subgraph): mask = np.zeros(self.patch_shape) - center = subgraph.node_voxel(0) offset = img_util.get_offset(center, self.patch_shape) for node1, node2 in subgraph.edges: v1 = subgraph.node_local_voxel(node1, offset) @@ -268,33 +270,27 @@ def _segment_mask_skeleton_only(self, subgraph): ) return mask - def _segment_mask_with_segmentation(self, center, subgraph): - mask = self.segmentation_reader.read(center, self.patch_shape) - mask = img_util.remove_small_segments(mask, 1000) - mask = 0.5 * (mask > 0).astype(float) - offset = img_util.get_offset(center, self.patch_shape) - for node1, node2 in subgraph.edges: - v1 = subgraph.node_local_voxel(node1, offset) - v2 = subgraph.node_local_voxel(node2, offset) - img_util.annotate_voxels( - mask, geometry_util.make_digital_line(v1, v2) - ) - return mask - - # --- Private helpers --- - def _list_indices(self): - # Set idxs - pos_idxs = np.arange(len(self.merge_sites)) - neg_idxs = np.arange(len(self.nonmerge_sites)) + # --- Helpers --- + def add_nonmerge_sites(self, num_sites): + # Generate sites + new_sites = list() + for _ in range(num_sites): + node, _ = self.get_random_nonmerge_site() + site = { + "node": node, + "xyz": self.graph.node_xyz[node], + "filename": "random" + } + new_sites.append(site) - # Check for class imbalance - if len(neg_idxs) < len(pos_idxs): - neg_idxs = -pos_idxs - else: - neg_idxs = -np.random.choice( - neg_idxs, size=len(pos_idxs), replace=False + # Add sites to existing + if len(self.nonmerge_sites) > 0: + df = pd.DataFrame(new_sites) + self.nonmerge_sites = pd.concat( + [df, self.nonmerge_sites], ignore_index=True ) - return np.concatenate((pos_idxs, neg_idxs)) + else: + self.nonmerge_sites = pd.DataFrame(new_sites) def _has_nearby_branching(self, root, max_depth=60): queue = [(root, 0)] @@ -313,13 +309,31 @@ def _has_nearby_branching(self, root, max_depth=60): visited.add(j) return False + def _list_indices(self): + # Set idxs + pos_idxs = np.arange(len(self.merge_sites)) + neg_idxs = np.arange(len(self.nonmerge_sites)) + + # Check for class imbalance + if len(neg_idxs) < len(pos_idxs): + neg_idxs = -pos_idxs + else: + neg_idxs = -np.random.choice( + neg_idxs, size=len(pos_idxs), replace=False + ) + return np.concatenate((pos_idxs, neg_idxs)) + def __len__(self): return len(self._list_indices()) - -# --------------------------------------------------------------------------- -# BrainDatasetCollection -# --------------------------------------------------------------------------- + def __repr__(self): + return ( + f"BrainDataset(" + f"brain_id={self.brain_id}, " + f"n_examples={len(self)}, " + f"n_pos_examples={len(self.merge_sites)}, " + f"n_neg_examples={len(self.nonmerge_sites)})" + ) class BrainDatasetCollection(Dataset): @@ -392,7 +406,6 @@ def get_idxs(self): return np.arange(len(self._index_table)) # --- Helpers --- - def brain_ids(self): """Returns the list of brain IDs in this collection.""" return [bd.brain_id for bd in self.brain_datasets] @@ -417,11 +430,7 @@ def __repr__(self): ) -# --------------------------------------------------------------------------- -# MergeSiteDataLoader -# --------------------------------------------------------------------------- - - +# --- Dataloader --- class ThreadedDataLoader(DataLoader): _VALID_MODALITIES = {None, "graph", "pointcloud"} @@ -469,7 +478,7 @@ def __iter__(self): np.random.shuffle(idxs) # Split into batches upfront - batch_index_groups = [ + batch_idx_groups = [ idxs[start: min(start + self.batch_size, len(idxs))] for start in range(0, len(idxs), self.batch_size) ] @@ -480,7 +489,7 @@ def __iter__(self): def prefetch_worker(): try: - for batch_idxs in batch_index_groups: + for batch_idxs in batch_idx_groups: buffer.put(self._load_batch(batch_idxs)) except Exception as e: buffer.put(e) @@ -608,9 +617,7 @@ def _load_image_graph_batch(self, idxs): h.append(h_i) x.append(x_i) edge_index.append(edge_index_i) - batches.append( - torch.full((n_i,), i, dtype=torch.long) - ) + batches.append(torch.full((n_i,), i, dtype=torch.long)) node_offset += n_i @@ -624,7 +631,55 @@ def _load_image_graph_batch(self, idxs): batch = ml_util.TensorDict( { "img": ml_util.to_tensor(patches), - "graph": (h, x, edge_index, batches) + "graph": (h, x, edge_index, batches), } ) return batch, ml_util.to_tensor(targets) + + +# --- Sites Loading --- +def create_dataset_collection( + img_prefixes_path, + root_path, + anisotropy=(1.0, 1.0, 1.0), + brightness_clip=500, + subgraph_radius=100, + node_spacing=5, + patch_shape=(128, 128, 128), + probability_random_nonmerge_site=0.5, + use_transform=False, +): + # Load image prefixes + bucket, root_prefix = util.parse_cloud_path(root_path) + img_prefixes = util.read_json(img_prefixes_path) + + # Iterate over brains + datasets = list() + for subprefix in util.list_gcs_subprefixes(bucket, root_prefix): + # Extract dataset info + brain_id = subprefix.split("/")[-2] + segmentation_id = get_segmentation_id(bucket, subprefix) + dataset_path = os.path.join(root_path, brain_id, segmentation_id) + img_path = os.path.join(img_prefixes[brain_id], "0") + + # Add dataset + dataset = BrainDataset( + brain_id, + dataset_path, + img_path, + anisotropy=anisotropy, + brightness_clip=brightness_clip, + subgraph_radius=subgraph_radius, + node_spacing=node_spacing, + patch_shape=patch_shape, + probability_random_nonmerge_site=probability_random_nonmerge_site, + use_transform=use_transform, + ) + datasets.append(dataset) + return BrainDatasetCollection(datasets) + + +def get_segmentation_id(bucket, prefix): + subprefixes = util.list_gcs_subprefixes(bucket, prefix) + assert len(subprefixes) == 1 + return subprefixes[0].split("/")[-2] diff --git a/src/neuron_proofreader/utils/swc_util.py b/src/neuron_proofreader/utils/swc_util.py index 6300fd53..659653a9 100644 --- a/src/neuron_proofreader/utils/swc_util.py +++ b/src/neuron_proofreader/utils/swc_util.py @@ -294,8 +294,8 @@ def read_from_cloud(self, path): elif zip_paths: read_fn = self.read_s3_zip if use_s3 else self.read_gcs_zip return self.read_zips(zip_paths, read_fn) - - raise Exception(f"SWC Pointer is invalid {path}") + else: + return list() def read_gcs_swc(self, path): """ diff --git a/src/neuron_proofreader/utils/util.py b/src/neuron_proofreader/utils/util.py index 5cf509b8..e9cdc398 100644 --- a/src/neuron_proofreader/utils/util.py +++ b/src/neuron_proofreader/utils/util.py @@ -451,16 +451,16 @@ def list_gcs_paths(bucket_name, prefix, extension=""): return paths -def list_gcs_subdirectories(bucket_name, prefix): +def list_gcs_subprefixes(bucket_name, prefix): """ Lists all direct subdirectories of a given prefix in a GCS bucket. Parameters ---------- - bucket : str - Name of bucket to be read from. + bucket_name : str + Name of bucket containing prefix. prefix : str - Path to directory in "bucket". + Path to location within bucket to be searched. Returns ------- @@ -468,8 +468,7 @@ def list_gcs_subdirectories(bucket_name, prefix): Direct subdirectories. """ # Load blobs - storage_client = storage.Client() - blobs = storage_client.list_blobs( + blobs = storage.Client().list_blobs( bucket_name, prefix=prefix, delimiter="/" ) [blob.name for blob in blobs] From a1712597f48d9bdbf3157eefd7190db17c44b269 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Tue, 2 Jun 2026 18:33:09 +0000 Subject: [PATCH 03/11] rm exaspim dataloader --- .../machine_learning/exaspim_dataloader.py | 526 ------------------ .../machine_learning/image_dataloader.py | 0 2 files changed, 526 deletions(-) delete mode 100644 src/neuron_proofreader/machine_learning/exaspim_dataloader.py create mode 100644 src/neuron_proofreader/machine_learning/image_dataloader.py diff --git a/src/neuron_proofreader/machine_learning/exaspim_dataloader.py b/src/neuron_proofreader/machine_learning/exaspim_dataloader.py deleted file mode 100644 index 7724dc86..00000000 --- a/src/neuron_proofreader/machine_learning/exaspim_dataloader.py +++ /dev/null @@ -1,526 +0,0 @@ -""" -Created on Jan 26 5:00:00 2026 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -Routines for loading image patches from whole-brain exaSPIM datasets. - -""" - -from concurrent.futures import ( - as_completed, - ProcessPoolExecutor, - ThreadPoolExecutor, -) -from torch.utils.data import IterableDataset - -import fastremap -import numpy as np -import random -import torch - -from neuron_proofreader.utils.img_util import TensorStoreReader -from neuron_proofreader.utils import swc_util, util - - -# --- Custom Datasets --- -class ExaspimDataset(IterableDataset): - """ - A PyTorch Dataset for sampling 3D patches from whole-brain images. The - dataset's __getitem__ method returns both raw image and segmentation - patches. Optionally, the patch sampling maybe biased towards foreground - regions. - """ - - def __init__( - self, - patch_shape, - brightness_clip=400, - boundary_buffer=5000, - foreground_sampling_rate=0.5, - n_examples_per_epoch=1000, - normalization_percentiles=(1, 99.9), - prefetch_foreground_sampling=32, - ): - """ - Instantiates an ExaspimDataset object. - - Parameters - ---------- - patch_shape : Tuple[int] - Shape of image patches to be read from image and segmentation. - brightness_clip : int, optional - Brightness intensity used as upper limit of image patch. Default - is 400. - boundary_buffer : int, optional - Image patches are sampled at least "boundary_buffer" voxels away - from boundary along each dimension. Default is 5000. - foreground_sampling_rate : float, optional - Rate at which image patches containing foreground objects are - sampled. Default is 0.5. - n_examples_per_epoch : int, optional - Number of examples generated for each epoch. Default is 1000. - normalization_percentiles : Tuple[float], optional - Upper and lower bounds of percentiles used to normalize image. - Default is (1, 99.9). - prefetch_foreground_sampling : int, optional - Number of image patches that are preloaded during foreground - search in "self.sample_segmentation_voxel" and - "self.sample_bright_voxel". Default is 32. - """ - # Call parent class - super(ExaspimDataset, self).__init__() - - # Class attributes - self.boundary_buffer = boundary_buffer - self.brightness_clip = brightness_clip - self.foreground_sampling_rate = foreground_sampling_rate - self.n_examples_per_epoch = n_examples_per_epoch - self.normalization_percentiles = normalization_percentiles - self.patch_shape = patch_shape - self.prefetch_foreground_sampling = prefetch_foreground_sampling - self.swc_reader = swc_util.Reader() - - # Data structures - self.imgs = dict() - self.segmentations = dict() - self.skeletons = dict() - - # --- Ingest Data --- - def ingest_brain( - self, brain_id, img_path, segmentation_path=None, swc_path=None - ): - """ - Loads a brain image, label mask, and skeletons, then stores each in - internal dictionaries. - - Parameters - ---------- - brain_id : str - Unique identifier for the brain corresponding to the image. - img_path : str - Path to whole-brain image to be read. - segmentation_path : str, optional - Path to segmentation. Default is None. - swc_path : str, optional - Path to SWC files. Default is None. - """ - # Load data - self.imgs[brain_id] = TensorStoreReader(img_path) - self._load_segmentation(brain_id, segmentation_path) - self._load_swcs(brain_id, swc_path) - - # Check image shapes - shape1 = self.imgs[brain_id].shape()[2::] - shape2 = self.segmentations[brain_id].shape() - assert shape1 == shape2, f"img_shape={shape1}, mask_shape={shape2}" - - def _load_segmentation(self, brain_id, path): - if path: - self.segmentations[brain_id] = TensorStoreReader(path) - - def _load_swcs(self, brain_id, swc_path): - if swc_path: - # Initializations - swc_dicts = self.swc_reader(swc_path) - n_points = np.sum([len(d["xyz"]) for d in swc_dicts]) - - # Extract skeleton voxels - if n_points > 0: - start = 0 - skeletons = np.zeros((n_points, 3), dtype=np.int32) - for swc_dict in swc_dicts: - end = start + len(swc_dict["xyz"]) - skeletons[start:end] = swc_dict["xyz"] - start = end - self.skeletons[brain_id] = skeletons[:, [2, 1, 0]] - - # --- Sample Image Patches --- - def __iter__(self): - """ - Returns a pair of noisy and BM4D-denoised image patches, normalized - according to percentile-based scaling. - - Returns - ------- - img : numpy.ndarray - Patch from raw image - mask : numpy.ndarray - Binarized mask from segmentation. - """ - for _ in range(self.n_examples_per_epoch): - # Get example - brain_id = self.sample_brain() - voxel = self.sample_voxel(brain_id) - img = self.read_image(brain_id, voxel) - mask = self.read_segmentation(brain_id, voxel) - - # Prepocess patches - img = self.preprocess_image(img) - mask = self.preprocess_mask(mask) - yield img, mask - - def sample_brain(self): - """ - Samples a brain ID from the loaded images. - - Returns - ------- - brain_id : str - Unique identifier of the sampled whole-brain. - """ - return util.sample_once(self.imgs.keys()) - - def sample_voxel(self, brain_id): - """ - Samples a voxel from a brain volume, either foreground or interior. - - Parameters - ---------- - brain_id : str - Unique identifier of the sampled whole-brain. - - Returns - ------- - Tuple[int] - Voxel coordinate chosen according to the foreground or interior - sampling strategy. - """ - if random.random() < self.foreground_sampling_rate: - return self.sample_foreground_voxel(brain_id) - else: - return self.sample_interior_voxel(brain_id) - - def sample_foreground_voxel(self, brain_id): - """ - Samples a voxel likely to be part of the foreground of a neuron. - - Parameters - ---------- - brain_id : str - Unique identifier of a whole-brain. - - Returns - ------- - Tuple[int] - Voxel coordinate representing a likely foreground location. - """ - if brain_id in self.skeletons and np.random.random() > 0.5: - return self.sample_skeleton_voxel(brain_id) - elif brain_id in self.segmentations: - return self.sample_segmentation_voxel(brain_id) - else: - return self.sample_bright_voxel(brain_id) - - def sample_interior_voxel(self, brain_id): - """ - Samples a random voxel coordinate from the interior of a 3D image - volume, avoiding boundary regions. - - Parameters - ---------- - brain_id : str - Unique identifier of a whole-brain. - - Returns - ------- - Tuple[int] - Voxel coordinate sampled uniformly at random within the valid - interior region of the image volume. - """ - voxel = list() - for s in self.imgs[brain_id].shape()[2::]: - upper = s - self.boundary_buffer - voxel.append(random.randint(self.boundary_buffer, upper)) - return tuple(voxel) - - def sample_skeleton_voxel(self, brain_id): - """ - Samples a voxel coordinate near a skeleton point. - - Parameters - ---------- - brain_id : str - Unique identifier of a whole-brain. - - Returns - ------- - Tuple[int] - Voxel coordinate near a skeleton point. - """ - idx = random.randint(0, len(self.skeletons[brain_id]) - 1) - shift = np.random.randint(0, 16, size=3) - return tuple(self.skeletons[brain_id][idx] + shift) - - def sample_segmentation_voxel(self, brain_id): - """ - Sample a voxel coordinate whose corresponding segmentation patch - contains a sufficiently large object. - - Parameters - ---------- - brain_id : str - Identifier for the image volume which must be a key in - "self.segmentations". - - Returns - ------- - best_voxel : Tuple[int] - Voxel coordinate whose patch contains a sufficiently large object - or had the largest object after 5 * self.prefetch attempts. - """ - best_volume = 0 - best_voxel = self.sample_interior_voxel(brain_id) - cnt = 0 - with ThreadPoolExecutor() as executor: - while best_volume < 1600: - # Read random image patches - pending = dict() - for _ in range(self.prefetch_foreground_sampling): - voxel = self.sample_interior_voxel(brain_id) - thread = executor.submit( - self.read_segmentation, brain_id, voxel - ) - pending[thread] = voxel - - # Check if labels patch has large enough object - for thread in as_completed(pending.keys()): - voxel = pending.pop(thread) - labels_patch = thread.result() - vals, cnts = fastremap.unique( - labels_patch, return_counts=True - ) - - if len(cnts) > 1: - volume = np.max(cnts[1:]) - if volume > best_volume: - best_voxel = voxel - best_volume = volume - - # Check number of tries - cnt += 1 - if cnt > 5: - break - return best_voxel - - def sample_bright_voxel(self, brain_id): - """ - Samples a voxel coordinate whose image patch is sufficiently bright. - - Parameters - ---------- - brain_id : str - Unique identifier of a whole-brain. - - Returns - ------- - best_voxel : Tuple[int] - Voxel coordinate whose patch is sufficiently bright or is the - highest observed brightness after 4 * self.prefetch attempts. - """ - best_brightness = 0 - best_voxel = self.sample_interior_voxel(brain_id) - cnt = 0 - with ThreadPoolExecutor() as executor: - while best_brightness < 1000: - # Read random image patches - pending = dict() - for _ in range(self.prefetch_foreground_sampling): - voxel = self.sample_interior_voxel(brain_id) - thread = executor.submit(self.read_image, brain_id, voxel) - pending[thread] = voxel - - # Check if image patch is bright enough - for thread in as_completed(pending.keys()): - voxel = pending.pop(thread) - img_patch = thread.result() - brightness = np.sum(img_patch > 100) - if brightness > best_brightness: - best_voxel = voxel - best_brightness = brightness - - # Check number of tries - cnt += 1 - if cnt > 5: - break - return best_voxel - - # --- Helpers --- - def __len__(self): - pass - - def preprocess_image(self, img): - """ - Preprocesses the given image by clipping the intensity values and - normalizing with a percentile-based scheme. - - Parameters - ---------- - img : numpy.ndarray - Image to be normalized - - Returns - ------- - img : numpy.ndarray - Normalized image. - """ - # Clip - img = np.minimum(img, self.brightness_clip) - - # Normalize - mn, mx = np.percentile(img, self.normalization_percentiles) - img = (img - mn) / (mx - mn + 1e-8) - return np.clip(img, 0, 1) - - def preprocess_mask(self, mask): - """ - Preprocesses the given segmentation mask by binarizing it. - - Parameters - ---------- - img : numpy.ndarray - Image to be normalized - - Returns - ------- - img : numpy.ndarray - Normalized image. - """ - - return (mask > 0).astype(int) - - def read_image(self, brain_id, voxel): - """ - Reads an image patch from the given brain at the specified location. - - Parameters - ---------- - brain_id : str - Unique identifier of whole-brain dataset. - voxel : Tuple[int] - Center of image patch to be read. - - Returns - ------- - numpy.ndarray - Image patch. - """ - return self.imgs[brain_id].read(voxel, self.patch_shape) - - def read_segmentation(self, brain_id, voxel): - """ - Reads a segmentation patch from the given brain at the specified' - location. - - Parameters - ---------- - brain_id : str - Unique identifier of whole-brain dataset. - voxel : Tuple[int] - Center of image patch to be read. - - Returns - ------- - numpy.ndarray - Segmentation patch. - """ - return self.segmentations[brain_id].read(voxel, self.patch_shape) - - -# --- Custom Dataloader --- -class DataLoader: - """ - DataLoader that uses multithreading to fetch image patches from the cloud - to form batches. - - Attributes - ---------- - dataset : torch.utils.data.Dataset - Dataset to iterated over. - batch_size : int - Number of examples in each batch. - patch_shape : Tuple[int] - Shape of image patch expected by the model. - """ - - def __init__(self, dataset, batch_size=16): - """ - Instantiates a DataLoader object. - - Parameters - ---------- - dataset : torch.utils.data.Dataset - Dataset to iterated over. - batch_size : int, optional - Number of examples in each batch. Default is 16. - """ - # Instance attributes - self.dataset = dataset - self.batch_size = batch_size - self.patch_shape = dataset.patch_shape - - def __iter__(self): - """ - Iterates over the dataset and yields batches of examples. - - Returns - ------- - iterator - Yields batches of examples. - """ - for idx in range(0, len(self.dataset), self.batch_size): - yield self._load_batch(idx) - - def _load_batch(self, start_idx): - # Compute batch size - n_remaining_examples = len(self.dataset) - start_idx - batch_size = min(self.batch_size, n_remaining_examples) - - # Generate batch - with ProcessPoolExecutor() as executor: - # Assign processs - processes = list() - for idx in range(start_idx, start_idx + batch_size): - processes.append( - executor.submit(self.dataset.__getitem__, idx) - ) - - # Process results - img_patches = np.zeros( - ( - batch_size, - 1, - ) - + self.patch_shape - ) - mask_patches = np.zeros( - ( - batch_size, - 1, - ) - + self.patch_shape - ) - for i, process in enumerate(as_completed(processes)): - img, mask = process.result() - img_patches[i, 0, ...] = img - mask_patches[i, 0, ...] = mask - return to_tensor(img_patches), to_tensor(mask_patches) - - -# --- Helpers --- -def to_tensor(arr): - """ - Converts the given NumPy array to a torch tensor. - - Parameters - ---------- - arr : numpy.ndarray - Array to be converted. - - Returns - ------- - torch.Tensor - Array converted to a torch tensor. - """ - return torch.tensor(arr, dtype=torch.float) diff --git a/src/neuron_proofreader/machine_learning/image_dataloader.py b/src/neuron_proofreader/machine_learning/image_dataloader.py new file mode 100644 index 00000000..e69de29b From 8436e9835545a88a1014dc7aa59677bb78fc4cd2 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Wed, 3 Jun 2026 03:30:20 +0000 Subject: [PATCH 04/11] refactor: image loader working for merge detection --- ...{augmentation.py => image_augmentation.py} | 0 .../machine_learning/image_dataloader.py | 281 ++++++++++++++++++ .../merge_datamodules_v2.py | 108 ++----- .../merge_feature_generation.py | 0 .../split_feature_extraction.py | 6 +- src/neuron_proofreader/utils/img_util.py | 140 +-------- src/neuron_proofreader/utils/util.py | 54 ---- 7 files changed, 313 insertions(+), 276 deletions(-) rename src/neuron_proofreader/machine_learning/{augmentation.py => image_augmentation.py} (100%) delete mode 100644 src/neuron_proofreader/merge_proofreading/merge_feature_generation.py diff --git a/src/neuron_proofreader/machine_learning/augmentation.py b/src/neuron_proofreader/machine_learning/image_augmentation.py similarity index 100% rename from src/neuron_proofreader/machine_learning/augmentation.py rename to src/neuron_proofreader/machine_learning/image_augmentation.py diff --git a/src/neuron_proofreader/machine_learning/image_dataloader.py b/src/neuron_proofreader/machine_learning/image_dataloader.py index e69de29b..b13aac5c 100644 --- a/src/neuron_proofreader/machine_learning/image_dataloader.py +++ b/src/neuron_proofreader/machine_learning/image_dataloader.py @@ -0,0 +1,281 @@ +""" +Created on Tue Jan 13 15:00:00 2026 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +Code for parallelizing reading image patches from the cloud. + +""" + +from abc import ABC, abstractmethod + +import numpy as np +import tensorstore as ts + +from neuron_proofreader.machine_learning.image_augmentation import ( + ImageTransforms +) +from neuron_proofreader.utils import geometry_util, img_util, util + + +# --- Image Reading --- +class TensorStoreImage: + """ + Class that reads images with the TensorStore library. + """ + + def __init__(self, img_path): + """ + Instantiates a TensorStoreImage object. + + Parameters + ---------- + img_path : str + Path to image. + """ + # Load image + bucket_name, inner_path = util.parse_cloud_path(img_path) + self.img = ts.open( + { + "driver": img_util.get_driver(img_path), + "kvstore": { + "driver": img_util.get_storage_driver(img_path), + "bucket": bucket_name, + "path": inner_path, + }, + "context": { + "cache_pool": {"total_bytes_limit": 1000000000}, + "cache_pool#remote": {"total_bytes_limit": 1000000000}, + "data_copy_concurrency": {"limit": 8}, + }, + "recheck_cached_data": "open", + } + ).result() + + # Check for Google segmentation + if "from_google" in img_path: + self.img = self.img[ts.d[:].transpose[3, 2, 1, 0]] + + # Check dimensions + while self.img.ndim < 5: + self.img = self.img[ts.newaxis, ...] + + def read(self, voxel, shape): + """ + Reads a patch from an image given a voxel coordinate and patch shape. + + Parameters + ---------- + voxel : Tuple[int] + Center of image patch to be read. + shape : Tuple[int] + Shape of image patch to be read. + + Returns + ------- + numpy.ndarray + Image patch. + """ + s = img_util.get_slices(voxel, shape) + return self.img[(0, 0, *s)].read().result() + + def shape(self): + """ + Gets the shape of image. + + Returns + ------- + Tuple[int] + Shape of image. + """ + return self.img.shape + + +# --- Patch Loading --- +class PatchLoader(ABC): + """ + A class for reading image patches and generating segment masks. + """ + + max_voxel_shift = 5 + + def __init__( + self, + graph, + img_path, + brightness_clip=400, + normalization_percentiles=(1, 99.5), + patch_shape=(128, 128, 128), + use_transform=False, + ): + """ + Instantiates a PatchLoader object. + + Parameters + ---------- + graph : SkeletonGraph + Graph used to compute patch voxel coordinates. + img_path : str + Path to whole-brain image. + brightness_clip : int, optional + Intensity value that voxel brightnesses are clipped to. + normalization_percentiles : Tuple[float], optional + Percentiles used to normalize patches. Default is (1, 99.5). + patch_shape : Tuple[int], optional + Shape of patch to be read from image. Default is (128, 128, 128). + """ + # Instance attributes + self.brightness_clip = brightness_clip + self.graph = graph + self.patch_shape = patch_shape + self.percentiles = normalization_percentiles + + # Image operations + self.img = TensorStoreImage(img_path) + self.transform = ImageTransforms() if use_transform else None + + # --- Abstract Interface --- + @abstractmethod + def __call__(self): + """ + Abstract method to be implemented by subclasses + """ + pass + + @abstractmethod + def compute_patch_specs(self): + """ + Abstract method to be implemented by subclasses. + """ + pass + + @abstractmethod + def create_mask(self): + """ + Abstract method to be implemented by subclasses. + """ + pass + + # --- Core Routines --- + def annotate_foreground(self, mask, nodes, offset, fill_val=1): + visited = set() + for i in nodes: + voxel_i = self.graph.node_local_voxel(i, offset) + for j in self.graph.neighbors(i): + if frozenset({i, j}) not in visited and j in nodes: + voxel_j = self.graph.node_local_voxel(j, offset) + voxels = geometry_util.make_digital_line(voxel_i, voxel_j) + img_util.annotate_voxels(mask, voxels, fill_val=fill_val) + visited.add(frozenset({i, j})) + + def annotate_fragment(self, mask, subgraph, offset, fill_val=1): + for node1, node2 in subgraph.edges: + # Get local voxel coordinates + voxel1 = subgraph.node_local_voxel(node1, offset) + voxel2 = subgraph.node_local_voxel(node2, offset) + + # Populate mask + voxels = geometry_util.make_digital_line(voxel1, voxel2) + img_util.annotate_voxels(mask, voxels, fill_val=fill_val) + + def read_image(self, center, shape): + """ + Reads the image patch specified by the given center and shape. + + Parameters + ---------- + center : Tuple[int] + Center of image patch to be read. + shape : Tuple[int] + Center of image patch to be read. + + Returns + ------- + patch : numpy.ndarray + Preprocessed image patch. + """ + patch = self.img.read(center, shape) + patch = np.minimum(patch, self.brightness_clip) + patch = img_util.normalize(patch, percentiles=self.percentiles) + return patch + + # --- Helpers --- + def adjust_voxel(self, voxel): + if self.transform: + voxel += np.random.randint( + -self.max_voxel_shift, self.max_voxel_shift + 1, size=3 + ) + return voxel + + @staticmethod + def stack(img, mask): + try: + patches = np.stack([img, mask], axis=0) + except ValueError: + img = img_util.pad_to_shape(img, mask.shape) + patches = np.stack([img, mask], axis=0) + return patches + + +class DetectionPatchLoader(PatchLoader): + + def __init__( + self, + graph, + img_path, + brightness_clip=400, + normalization_percentiles=(1, 99.5), + patch_shape=(128, 128, 128), + use_transform=False, + ): + # Call parent class + super().__init__( + graph, + img_path, + brightness_clip=brightness_clip, + normalization_percentiles=normalization_percentiles, + patch_shape=patch_shape, + use_transform=use_transform, + ) + + # --- Implementation of Abstract Inferface --- + def __call__(self, node): + # Get patches + center, shape = self.compute_patch_specs(node) + img = self.read_image(center, shape) + mask = self.create_mask(center, shape, node) + patches = self.stack(img, mask) + + # Check whether to apply image augmentation + if self.transform: + patches = self.transform(patches) + return patches + + def compute_patch_specs(self, node): + voxel = self.graph.node_voxel(node) + voxel = self.adjust_voxel(voxel) + return voxel, self.patch_shape + + def create_mask(self, center, shape, node): + # Initializations + offset = img_util.get_offset(center, shape) + depth = np.sqrt(2) * np.max(shape) / (2 * self.graph.anisotropy.min()) + nodes = self.get_foreground_nodes(node, depth) + subgraph = self.graph.rooted_subgraph(node, depth) + + # Annotate mask + mask = np.zeros(shape) + self.annotate_foreground(mask, nodes, offset, fill_val=0.5) + self.annotate_fragment(mask, subgraph, offset, fill_val=1) + return mask + + # --- Helpers --- + def get_foreground_nodes(self, node, radius): + xyz = self.graph.node_xyz[node] + nodes = self.graph.kdtree.query_ball_point(xyz, radius) + return nodes + + +class ProposalPatchLoader(PatchLoader): + pass diff --git a/src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py b/src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py index ea6d25a3..54e3861f 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py @@ -38,7 +38,9 @@ import threading import torch -from neuron_proofreader.machine_learning.augmentation import ImageTransforms +from neuron_proofreader.machine_learning.image_dataloader import ( + DetectionPatchLoader, +) from neuron_proofreader.machine_learning.geometric_gnn_models import ( subgraph_to_data, ) @@ -46,13 +48,7 @@ subgraph_to_point_cloud, ) from neuron_proofreader.skeleton_graph import SkeletonGraph -from neuron_proofreader.utils import ( - geometry_util, - img_util, - ml_util, - swc_util, - util, -) +from neuron_proofreader.utils import ml_util, swc_util, util # -- Datasets --- @@ -87,8 +83,9 @@ def __init__( brightness_clip=500, subgraph_radius=100, node_spacing=5, + normalization_percentiles=(1, 99.5), patch_shape=(128, 128, 128), - probability_random_nonmerge_site=0.5, + random_nonmerge_site_prob=0.5, use_transform=False, ): # Instance attributes @@ -98,28 +95,33 @@ def __init__( self.ignore_fragments = set() self.node_spacing = node_spacing self.patch_shape = patch_shape - self.probability_random_nonmerge_site = ( - probability_random_nonmerge_site - ) + self.random_nonmerge_site_prob = random_nonmerge_site_prob self.subgraph_radius = subgraph_radius # Core data structures self.graph = self.load_fragments(sites_prefix) - self.img = img_util.TensorStoreImage(img_path) - self.set_giant_components() - self.merge_sites = self.load_sites( - os.path.join(sites_prefix, "merge_sites/") + os.path.join(sites_prefix, "merge_sites") ) self.nonmerge_sites = self.load_sites( - os.path.join(sites_prefix, "nonmerge_sites/") + os.path.join(sites_prefix, "nonmerge_sites") ) + + # Image loading + self.patch_loader = DetectionPatchLoader( + self.graph, + img_path, + brightness_clip=brightness_clip, + normalization_percentiles=normalization_percentiles, + patch_shape=patch_shape, + use_transform=use_transform, + ) + + # Store dataset info + self.set_giant_components() self.set_merge_site_info() self.set_valid_branching_nodes() - # Image augmentation for training - self.transform = ImageTransforms() if use_transform else None - def load_fragments(self, sites_prefix): graph = SkeletonGraph( anisotropy=self.anisotropy, @@ -189,31 +191,15 @@ def set_valid_branching_nodes(self): # --- Site Retrieval --- def __getitem__(self, idx): - # Get example node, label = self.get_site(idx) subgraph = self.graph.rooted_subgraph(node, self.subgraph_radius) - - # Get voxel coordinate - voxel = subgraph.node_voxel(0) - if self.transform: - voxel += np.random.randint(-6, 6 + 1, size=3) - - # Extract subgraph and image patches centered at site - img_patch = self.get_img_patch(voxel) - segment_mask = self.get_segment_mask(voxel, subgraph) - - # Stack image channels - try: - patches = np.stack([img_patch, segment_mask], axis=0) - except ValueError: - img_patch = img_util.pad_to_shape(img_patch, self.patch_shape) - patches = np.stack([img_patch, segment_mask], axis=0) + patches = self.patch_loader(node) return patches, subgraph, label def get_site(self, idx): if idx > 0: return self.merge_sites["node"][idx], 1 - elif np.random.random() < self.probability_random_nonmerge_site: + elif np.random.random() < self.random_nonmerge_site_prob: return self.get_random_nonmerge_site() elif abs(idx) < len(self.nonmerge_sites): return self.nonmerge_sites["node"][abs(idx)], 0 @@ -228,48 +214,6 @@ def get_random_nonmerge_site(self): node = util.sample_once(self.graph.nodes) return node, 0 - # --- Image / Mask Extraction --- - def get_img_patch(self, center): - """ - Extracts, clips, and normalises a 3D image patch centred at center. - - Parameters - ---------- - center : numpy.ndarray - Voxel coordinates of the patch centre. - - Returns - ------- - numpy.ndarray - """ - patch = self.img.read(center, self.patch_shape) - patch = np.minimum(patch, self.brightness_clip) - return img_util.normalize(patch) - - def get_segment_mask(self, center, subgraph): - """ - Builds the segment mask for subgraph. - - Parameters - ---------- - center : numpy.ndarray - Voxel coordinates of the patch centre. - subgraph : SkeletonGraph - - Returns - ------- - numpy.ndarray - """ - mask = np.zeros(self.patch_shape) - offset = img_util.get_offset(center, self.patch_shape) - for node1, node2 in subgraph.edges: - v1 = subgraph.node_local_voxel(node1, offset) - v2 = subgraph.node_local_voxel(node2, offset) - img_util.annotate_voxels( - mask, geometry_util.make_digital_line(v1, v2) - ) - return mask - # --- Helpers --- def add_nonmerge_sites(self, num_sites): # Generate sites @@ -646,7 +590,7 @@ def create_dataset_collection( subgraph_radius=100, node_spacing=5, patch_shape=(128, 128, 128), - probability_random_nonmerge_site=0.5, + random_nonmerge_site_prob=0.5, use_transform=False, ): # Load image prefixes @@ -672,7 +616,7 @@ def create_dataset_collection( subgraph_radius=subgraph_radius, node_spacing=node_spacing, patch_shape=patch_shape, - probability_random_nonmerge_site=probability_random_nonmerge_site, + random_nonmerge_site_prob=random_nonmerge_site_prob, use_transform=use_transform, ) datasets.append(dataset) diff --git a/src/neuron_proofreader/merge_proofreading/merge_feature_generation.py b/src/neuron_proofreader/merge_proofreading/merge_feature_generation.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py index 2b6b81cf..7976898e 100644 --- a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py +++ b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py @@ -326,7 +326,7 @@ def create_segment_mask(self, proposal, shape, offset): if frozenset({i, j}) not in visited and j in nodes: voxel_j = self.graph.node_local_voxel(j, offset) voxels = geometry_util.make_digital_line(voxel_i, voxel_j) - img_util.annotate_voxels(mask, voxels, val=0.25) + img_util.annotate_voxels(mask, voxels, fill_val=0.25) visited.add(frozenset({i, j})) return mask @@ -503,14 +503,14 @@ def annotate_edge(self, node): node : int Node ID used to get branch to be annotated. """ - img_util.annotate_voxels(self.mask, self.voxels[node], val=0.5) + img_util.annotate_voxels(self.mask, self.voxels[node], fill_val=0.5) def annotate_proposal(self): """ Annotates the proposal within the given mask. """ voxels = self.get_profile_line() - img_util.annotate_voxels(self.mask, voxels, val=1) + img_util.annotate_voxels(self.mask, voxels, fill_val=1) def get_profile_line(self, n_pts=None): """ diff --git a/src/neuron_proofreader/utils/img_util.py b/src/neuron_proofreader/utils/img_util.py index 04d74daf..dfc8a82b 100644 --- a/src/neuron_proofreader/utils/img_util.py +++ b/src/neuron_proofreader/utils/img_util.py @@ -20,77 +20,6 @@ from neuron_proofreader.utils import util -class TensorStoreImage: - """ - Class that reads images with the TensorStore library. - """ - - def __init__(self, img_path): - """ - Instantiates a TensorStoreImage object. - - Parameters - ---------- - img_path : str - Path to image. - """ - # Load image - bucket_name, inner_path = util.parse_cloud_path(img_path) - self.img = ts.open( - { - "driver": get_driver(img_path), - "kvstore": { - "driver": get_storage_driver(img_path), - "bucket": bucket_name, - "path": inner_path, - }, - "context": { - "cache_pool": {"total_bytes_limit": 1000000000}, - "cache_pool#remote": {"total_bytes_limit": 1000000000}, - "data_copy_concurrency": {"limit": 8}, - }, - "recheck_cached_data": "open", - } - ).result() - - # Check for Google segmentation - if "from_google" in img_path: - self.img = self.img[ts.d[:].transpose[3, 2, 1, 0]] - - # Check dimensions - while self.img.ndim < 5: - self.img = self.img[ts.newaxis, ...] - - def read(self, voxel, shape): - """ - Reads a patch from an image given a voxel coordinate and patch shape. - - Parameters - ---------- - voxel : Tuple[int] - Center of image patch to be read. - shape : Tuple[int] - Shape of image patch to be read. - - Returns - ------- - numpy.ndarray - Image patch. - """ - return self.img[(0, 0, *get_slices(voxel, shape))].read().result() - - def shape(self): - """ - Gets the shape of image. - - Returns - ------- - Tuple[int] - Shape of image. - """ - return self.img.shape - - # --- Visualization --- def make_segmentation_colormap(mask, seed=42): """ @@ -218,7 +147,7 @@ def plot_segmentation_mips(segmentation): # --- Helpers --- -def annotate_voxels(img, voxels, kernel_size=3, val=1): +def annotate_voxels(img, voxels, kernel_size=3, fill_val=1): """ Annotates voxel coordinates in a 3D image by filling a patch around each voxel with a given value. @@ -231,7 +160,7 @@ def annotate_voxels(img, voxels, kernel_size=3, val=1): Voxel coordinates to annotate. kernel_size : int, optional Size of kernel used to fill around each voxel. Default is 3. - val : int, optional + fill_val : int, optional Fill value. Default is 1. """ buffer = (kernel_size - 1) // 2 @@ -239,7 +168,7 @@ def annotate_voxels(img, voxels, kernel_size=3, val=1): for voxel in voxels: if is_contained(voxel, img.shape, buffer=buffer): s = get_slices(voxel, shape) - img[s] = val + img[s] = fill_val def compute_iou3d(c1, c2, s1, s2): @@ -346,44 +275,6 @@ def get_driver(img_path): raise Exception(f"Invalid image path at {img_path}") -def get_neighbors(voxel, shape): - """ - Gets the neighbors of a given voxel coordinate. - - Parameters - ---------- - voxel : Tuple[int] - Voxel coordinate in a 3D image. - shape : Tuple[int] - Shape of the 3D image that voxel is contained within. - - Returns - ------- - neighbors : List[Tuple[int]] - Voxel coordinates of the 26 neighbors of the given voxel. - """ - # Initializations - x, y, z = voxel - depth, height, width = shape - - # Iterate over the possible offsets for x, y, and z - neighbors = [] - for dx in [-1, 0, 1]: - for dy in [-1, 0, 1]: - for dz in [-1, 0, 1]: - # Skip the (0, 0, 0) offset - if dx == 0 and dy == 0 and dz == 0: - continue - - # Calculate the neighbor's coordinates - nx, ny, nz = x + dx, y + dy, z + dz - - # Check if the neighbor is within the bounds of the 3D image - if 0 <= nx < depth and 0 <= ny < height and 0 <= nz < width: - neighbors.append((nx, ny, nz)) - return neighbors - - def get_offset(center, shape): """ Computes the spatial offset of a crop given its center and shape. @@ -582,31 +473,6 @@ def pad_to_shape(img, target_shape, pad_value=0): return np.pad(img, pads, mode="constant", constant_values=pad_value) -def remove_small_segments(segmentation, min_size): - """ - Removes small segments from a segmentation. - - Parameters - ---------- - segmentation : numpy.ndarray - Integer array representing a segmentation mask. Each unique - nonzero value corresponds to a distinct segment. - min_size : int - Minimum size (in voxels) for a segment to be kept. - - Returns - ------- - segmentation : numpy.ndarray - New segmentation of the same shape as the input, with only the - retained segments renumbered contiguously. - """ - ids, cnts = unique(segmentation, return_counts=True) - ids = [i for i, cnt in zip(ids, cnts) if cnt > min_size and i != 0] - ids = mask_except(segmentation, ids) - segmentation, _ = renumber(ids, preserve_zero=True, in_place=True) - return segmentation - - def resize(img, new_shape): """ Resizes a 3D image to the new shape using linear interpolation. diff --git a/src/neuron_proofreader/utils/util.py b/src/neuron_proofreader/utils/util.py index ca358f18..efb8668c 100644 --- a/src/neuron_proofreader/utils/util.py +++ b/src/neuron_proofreader/utils/util.py @@ -597,60 +597,6 @@ def read_s3_txt(path): return obj["Body"].read().decode("utf-8") -def list_s3_paths(bucket_name, prefix, extension=""): - """ - Lists all object keys in a public S3 bucket under a given prefix, - optionally filters by file extension. - - Parameters - ---------- - bucket_name : str - Name of the S3 bucket. - prefix : str - Prefix to search under. - extension : str, optional - File extension to filter by. Default is an empty string. - - Returns - ------- - paths : List[str] - S3 object keys that match the prefix and extension filter. - """ - # Create an anonymous client for public buckets - s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED)) - response = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix) - - # List all objects under the prefix - paths = list() - if "Contents" in response: - for obj in response["Contents"]: - filename = obj["Key"] - if filename.endswith(extension): - path = os.path.join(f"s3://{bucket_name}", filename) - paths.append(path) - return paths - - -def read_s3_txt(path): - """ - Reads a txt file stored in an S3 bucket. - - Parameters - ---------- - path : str - Path to txt file to be read. - - Returns - ------- - str - Contents of txt file. - """ - bucket_name, subpath = parse_cloud_path(path) - s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED)) - obj = s3.get_object(Bucket=bucket_name, Key=subpath) - return obj["Body"].read().decode("utf-8") - - def upload_dir_to_s3(dir_path, bucket_name, prefix): """ Uploads a directory on the local machine to an S3 bucket. From efbb6404260259a2905a6eb70258c6f0e6b56be0 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Wed, 3 Jun 2026 06:51:28 +0000 Subject: [PATCH 05/11] refactor: image config --- src/neuron_proofreader/config.py | 141 +++++++++++------- .../machine_learning/image_dataloader.py | 48 +----- .../merge_datamodules_v2.py | 81 +++++----- .../merge_proofreading/merge_inference.py | 2 +- src/neuron_proofreader/proposal_graph.py | 6 +- src/neuron_proofreader/utils/img_util.py | 1 - src/neuron_proofreader/utils/swc_util.py | 2 +- src/neuron_proofreader/utils/util.py | 78 +++++++--- 8 files changed, 188 insertions(+), 171 deletions(-) diff --git a/src/neuron_proofreader/config.py b/src/neuron_proofreader/config.py index 4cf9ffb6..eb8932b2 100644 --- a/src/neuron_proofreader/config.py +++ b/src/neuron_proofreader/config.py @@ -17,6 +17,92 @@ from neuron_proofreader.utils import util +@dataclass +class GraphConfig: + pass + + +@dataclass +class ImageConfig: + """ + Configuration class for image processing parameters. + + Attributes + ---------- + brightness_clip : int + Intensity value that voxel brightness is clipped to. + percentiles : Tuple[float], optional + Percentiles used to normalize patches. + patch_shape : Tuple[int] + Shape of patch to be read from image. + transform : bool + Indication of whether to use image augmentation. + """ + brightness_clip: int = 400 + percentiles: tuple = (1, 99.5) + patch_shape: tuple = (128, 128, 128) + transform: bool = False + + +@dataclass +class MLConfig: + """ + Configuration class for machine learning model parameters. + + Attributes + ---------- + batch_size : int + The number of samples processed in one batch during training or + inference. Default is 64. + brightness_clip : int + Maximum brightness value that image intensities are clipped to. + Default is 400. + device : str + Device to load model onto. Default is "cuda". + model_name : str + Name of model used to perform inference. Default is None. + patch_shape : Tuple[int] + Shape of image patch expected by vision model. Default is (96, 96, 96). + transform : bool + Indication of whether to apply data augmentation to image patches. + Default is False. + """ + batch_size: int = 64 + brightness_clip: int = 400 + device: str = "cuda" + model_name: str = None + patch_shape: tuple = (96, 96, 96) + transform: bool = False + + def to_dict(self): + """ + Converts configuration attributes to a dictionary. + + Returns + ------- + dict + Dictionary containing configuration attributes. + """ + attributes = dict() + for k, v in vars(self).items(): + if isinstance(v, tuple): + attributes[k] = list(v) + else: + attributes[k] = v + return attributes + + def save(self, path): + """ + Saves configuration attributes to a JSON file. + """ + util.write_json(path, self.to_dict()) + + +@dataclass +class ProposalsConfig: + pass + + @dataclass class ProposalGraphConfig: """ @@ -90,61 +176,6 @@ def save(self, path): util.write_json(path, self.to_dict()) -@dataclass -class MLConfig: - """ - Configuration class for machine learning model parameters. - - Attributes - ---------- - batch_size : int - The number of samples processed in one batch during training or - inference. Default is 64. - brightness_clip : int - Maximum brightness value that image intensities are clipped to. - Default is 400. - device : str - Device to load model onto. Default is "cuda". - model_name : str - Name of model used to perform inference. Default is None. - patch_shape : Tuple[int] - Shape of image patch expected by vision model. Default is (96, 96, 96). - transform : bool - Indication of whether to apply data augmentation to image patches. - Default is False. - """ - - batch_size: int = 64 - brightness_clip: int = 400 - device: str = "cuda" - model_name: str = None - patch_shape: tuple = (96, 96, 96) - transform: bool = False - - def to_dict(self): - """ - Converts configuration attributes to a dictionary. - - Returns - ------- - dict - Dictionary containing configuration attributes. - """ - attributes = dict() - for k, v in vars(self).items(): - if isinstance(v, tuple): - attributes[k] = list(v) - else: - attributes[k] = v - return attributes - - def save(self, path): - """ - Saves configuration attributes to a JSON file. - """ - util.write_json(path, self.to_dict()) - - @dataclass class Config: """ diff --git a/src/neuron_proofreader/machine_learning/image_dataloader.py b/src/neuron_proofreader/machine_learning/image_dataloader.py index b13aac5c..ae6deba0 100644 --- a/src/neuron_proofreader/machine_learning/image_dataloader.py +++ b/src/neuron_proofreader/machine_learning/image_dataloader.py @@ -100,15 +100,7 @@ class PatchLoader(ABC): max_voxel_shift = 5 - def __init__( - self, - graph, - img_path, - brightness_clip=400, - normalization_percentiles=(1, 99.5), - patch_shape=(128, 128, 128), - use_transform=False, - ): + def __init__(self, graph, img_config, img_path): """ Instantiates a PatchLoader object. @@ -118,22 +110,12 @@ def __init__( Graph used to compute patch voxel coordinates. img_path : str Path to whole-brain image. - brightness_clip : int, optional - Intensity value that voxel brightnesses are clipped to. - normalization_percentiles : Tuple[float], optional - Percentiles used to normalize patches. Default is (1, 99.5). - patch_shape : Tuple[int], optional - Shape of patch to be read from image. Default is (128, 128, 128). + """ - # Instance attributes - self.brightness_clip = brightness_clip + self.config = img_config self.graph = graph - self.patch_shape = patch_shape - self.percentiles = normalization_percentiles - - # Image operations self.img = TensorStoreImage(img_path) - self.transform = ImageTransforms() if use_transform else None + self.transform = ImageTransforms() if img_config.transform else None # --- Abstract Interface --- @abstractmethod @@ -201,6 +183,9 @@ def read_image(self, center, shape): return patch # --- Helpers --- + def __getattr__(self, name): + return getattr(self.config, name) + def adjust_voxel(self, voxel): if self.transform: voxel += np.random.randint( @@ -220,25 +205,6 @@ def stack(img, mask): class DetectionPatchLoader(PatchLoader): - def __init__( - self, - graph, - img_path, - brightness_clip=400, - normalization_percentiles=(1, 99.5), - patch_shape=(128, 128, 128), - use_transform=False, - ): - # Call parent class - super().__init__( - graph, - img_path, - brightness_clip=brightness_clip, - normalization_percentiles=normalization_percentiles, - patch_shape=patch_shape, - use_transform=use_transform, - ) - # --- Implementation of Abstract Inferface --- def __call__(self, node): # Get patches diff --git a/src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py b/src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py index 54e3861f..21776d86 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py @@ -39,7 +39,7 @@ import torch from neuron_proofreader.machine_learning.image_dataloader import ( - DetectionPatchLoader, + DetectionPatchLoader as PatchLoader, ) from neuron_proofreader.machine_learning.geometric_gnn_models import ( subgraph_to_data, @@ -68,7 +68,7 @@ class BrainDataset: Spacing (in microns) between neighbouring graph nodes. patch_shape : Tuple[int] Shape of the 3D image patches to extract. - subgraph_radius : float + subgraph_depth : float Radius (in microns) used when extracting rooted subgraphs. """ @@ -77,59 +77,44 @@ class BrainDataset: def __init__( self, brain_id, - sites_prefix, img_path, + sites_prefix, + swcs_path, anisotropy=(1.0, 1.0, 1.0), - brightness_clip=500, - subgraph_radius=100, + img_config=None, node_spacing=5, - normalization_percentiles=(1, 99.5), - patch_shape=(128, 128, 128), random_nonmerge_site_prob=0.5, - use_transform=False, + subgraph_depth=100, ): # Instance attributes - self.anisotropy = anisotropy self.brain_id = brain_id - self.brightness_clip = brightness_clip self.ignore_fragments = set() - self.node_spacing = node_spacing - self.patch_shape = patch_shape self.random_nonmerge_site_prob = random_nonmerge_site_prob - self.subgraph_radius = subgraph_radius + self.subgraph_depth = subgraph_depth # Core data structures - self.graph = self.load_fragments(sites_prefix) + self.graph = self.load_fragments(anisotropy, node_spacing, swcs_path) self.merge_sites = self.load_sites( os.path.join(sites_prefix, "merge_sites") ) self.nonmerge_sites = self.load_sites( os.path.join(sites_prefix, "nonmerge_sites") ) - - # Image loading - self.patch_loader = DetectionPatchLoader( - self.graph, - img_path, - brightness_clip=brightness_clip, - normalization_percentiles=normalization_percentiles, - patch_shape=patch_shape, - use_transform=use_transform, - ) + self.patch_loader = PatchLoader(self.graph, img_config, img_path) # Store dataset info self.set_giant_components() self.set_merge_site_info() self.set_valid_branching_nodes() - def load_fragments(self, sites_prefix): + def load_fragments(self, anisotropy, node_spacing, swcs_path): graph = SkeletonGraph( - anisotropy=self.anisotropy, - node_spacing=self.node_spacing, + anisotropy=anisotropy, + node_spacing=node_spacing, use_anisotropy=False, verbose=True, ) - graph.load(os.path.join(sites_prefix, "fragments")) + graph.load(swcs_path) return graph def load_sites(self, sites_prefix): @@ -192,7 +177,7 @@ def set_valid_branching_nodes(self): # --- Site Retrieval --- def __getitem__(self, idx): node, label = self.get_site(idx) - subgraph = self.graph.rooted_subgraph(node, self.subgraph_radius) + subgraph = self.graph.rooted_subgraph(node, self.subgraph_depth) patches = self.patch_loader(node) return patches, subgraph, label @@ -404,7 +389,7 @@ def __init__( self.modality = modality self.use_shuffle = use_shuffle self.prefetch_batches = prefetch_batches - self.patches_shape = (2,) + dataset.brain_datasets[0].patch_shape + self.patches_shape = (2,) + dataset.brain_datasets[0].patch_loader.patch_shape # Set batch loader if self.is_multimodal and self.modality == "graph": @@ -583,47 +568,49 @@ def _load_image_graph_batch(self, idxs): # --- Sites Loading --- def create_dataset_collection( + brain_ids, img_prefixes_path, - root_path, + sites_root_path, + swcs_root_path, anisotropy=(1.0, 1.0, 1.0), - brightness_clip=500, - subgraph_radius=100, + img_config=None, node_spacing=5, - patch_shape=(128, 128, 128), random_nonmerge_site_prob=0.5, - use_transform=False, + subgraph_depth=100, ): # Load image prefixes - bucket, root_prefix = util.parse_cloud_path(root_path) + bucket, root_prefix = util.parse_cloud_path(sites_root_path) img_prefixes = util.read_json(img_prefixes_path) # Iterate over brains datasets = list() - for subprefix in util.list_gcs_subprefixes(bucket, root_prefix): + for brain_id in brain_ids: # Extract dataset info - brain_id = subprefix.split("/")[-2] - segmentation_id = get_segmentation_id(bucket, subprefix) - dataset_path = os.path.join(root_path, brain_id, segmentation_id) img_path = os.path.join(img_prefixes[brain_id], "0") + segmentation_id = get_segmentation_id(sites_root_path, brain_id) + sites_path = os.path.join(sites_root_path, brain_id, segmentation_id) + swcs_path = util.get_google_swcs_prefix( + swcs_root_path, brain_id, segmentation_id + ) # Add dataset dataset = BrainDataset( brain_id, - dataset_path, img_path, + sites_path, + swcs_path, anisotropy=anisotropy, - brightness_clip=brightness_clip, - subgraph_radius=subgraph_radius, + img_config=img_config, + subgraph_depth=subgraph_depth, node_spacing=node_spacing, - patch_shape=patch_shape, random_nonmerge_site_prob=random_nonmerge_site_prob, - use_transform=use_transform, ) datasets.append(dataset) return BrainDatasetCollection(datasets) -def get_segmentation_id(bucket, prefix): - subprefixes = util.list_gcs_subprefixes(bucket, prefix) +def get_segmentation_id(sites_path, brain_id): + brain_sites_path = os.path.join(sites_path, brain_id) + subprefixes = util.list_gcs_subprefixes(brain_sites_path) assert len(subprefixes) == 1 return subprefixes[0].split("/")[-2] diff --git a/src/neuron_proofreader/merge_proofreading/merge_inference.py b/src/neuron_proofreader/merge_proofreading/merge_inference.py index 474ea3a9..ac1b8ce5 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_inference.py +++ b/src/neuron_proofreader/merge_proofreading/merge_inference.py @@ -244,7 +244,7 @@ def __init__( img_path, patch_shape, batch_size=16, - brightness_clip=300, + brightness_clip=400, is_multimodal=False, min_search_size=0, prefetch=64, diff --git a/src/neuron_proofreader/proposal_graph.py b/src/neuron_proofreader/proposal_graph.py index 37382210..7a9616e8 100644 --- a/src/neuron_proofreader/proposal_graph.py +++ b/src/neuron_proofreader/proposal_graph.py @@ -41,7 +41,6 @@ def __init__( min_cable_length=0, node_spacing=1, prune_depth=20.0, - remove_high_risk_merges=False, verbose=True, ): """ @@ -55,13 +54,10 @@ def __init__( min_cable_length : float, optional Minimum cable length of fragments loaded into graph. Default is 0. node_spacing : float, optional - Distance between points in edges. + Distance (in microns) between nodes. Default is 1. prune_depth : float, optional Branches with length less than "prune_depth" microns are removed. Default is 20um. - remove_high_risk_merges : bool, optional - Indication of whether to remove high risk merge sites (i.e. close - branching points). Default is False. verbose : bool, optional Indication of whether to display a progress bar while building graph. Default is True. diff --git a/src/neuron_proofreader/utils/img_util.py b/src/neuron_proofreader/utils/img_util.py index dfc8a82b..c0972f02 100644 --- a/src/neuron_proofreader/utils/img_util.py +++ b/src/neuron_proofreader/utils/img_util.py @@ -8,7 +8,6 @@ """ -from fastremap import mask_except, renumber, unique from matplotlib.colors import ListedColormap from scipy.ndimage import zoom diff --git a/src/neuron_proofreader/utils/swc_util.py b/src/neuron_proofreader/utils/swc_util.py index 659653a9..c57c2ad0 100644 --- a/src/neuron_proofreader/utils/swc_util.py +++ b/src/neuron_proofreader/utils/swc_util.py @@ -206,7 +206,7 @@ def read_zips(self, zip_paths, read_fn): for process in as_completed(futures): try: swc_dicts.extend(process.result()) - except RefreshError: + except (TransportError, RefreshError): pass if self.verbose: diff --git a/src/neuron_proofreader/utils/util.py b/src/neuron_proofreader/utils/util.py index efb8668c..6621b40d 100644 --- a/src/neuron_proofreader/utils/util.py +++ b/src/neuron_proofreader/utils/util.py @@ -24,7 +24,7 @@ import shutil -# --- OS utils --- +# --- OS Utils --- def listdir(path, extension=None): """ Lists all files in the directory at "path". If an extension is @@ -179,7 +179,7 @@ def set_filename_in_zip(zipfile, name): return filename -# --- IO utils --- +# --- IO Utils --- def combine_zips(zip_paths, output_zip_path): """ Combines a list of ZIP archives into a single ZIP archive. @@ -334,6 +334,31 @@ def write_txt(path, contents): # --- Cloud Utils --- +def get_google_swcs_prefix(root_prefix, brain_id, segmentation_id): + # Determine old vs. new result + prefix1 = os.path.join(root_prefix, brain_id, "whole_brain") + prefix2 = os.path.join(root_prefix, "whole_brain", brain_id) + if check_gcs_exists(prefix1, is_prefix=True): + prefix = prefix1 + elif check_gcs_exists(prefix2, is_prefix=True): + prefix = prefix2 + else: + raise Exception("Unable to find Google swcs result!") + + # Get SWC dirname + prefix = os.path.join(prefix, segmentation_id) + dirname = get_google_swcs_dirname(prefix) + return os.path.join(prefix, dirname) + + +def get_google_swcs_dirname(prefix): + for subprefix in list_gcs_subprefixes(prefix): + dirname = subprefix.split("/")[-2] + if "swc" in dirname: + return dirname + return "swcs" + + def list_cloud_paths(path, extension=""): """ Lists all files in a GCS/S3 bucket with the given extension. @@ -386,26 +411,38 @@ def parse_cloud_path(path): # --- GCS Utils --- -def check_gcs_file_exists(bucket_name, path): +def check_gcs_exists(path, is_prefix=False): """ - Checks if the given path exists. - + Checks if a file or prefix exists in GCS. Parameters ---------- - bucket_name : str - Name of bucket to be checked. path : str - Path to be checked. - + GCS path to check. + prefix : bool + If True, checks whether any object exists under the given prefix. + If False, checks whether the exact file exists. Returns ------- bool Indication of whether the path exists. """ - client = storage.Client() - bucket = client.bucket(bucket_name) - blob = bucket.blob(path) - return blob.exists() + bucket_name, key = parse_cloud_path(path) + bucket = storage.Client().bucket(bucket_name) + if is_prefix: + key = key.rstrip("/") + "/" + return any(bucket.list_blobs(prefix=key, max_results=1)) + else: + return bucket.blob(key).exists() + + +def check_gcs_prefix_exists(path): + bucket_name, prefix = parse_cloud_path(path) + prefix = prefix.rstrip("/") + "/" + bucket = storage.Client().bucket(bucket_name) + exists = any( + bucket.list_blobs(prefix=prefix, max_results=1) + ) + return exists def is_gcs_path(path): @@ -451,22 +488,23 @@ def list_gcs_paths(bucket_name, prefix, extension=""): return paths -def list_gcs_subprefixes(bucket_name, prefix): +def list_gcs_subprefixes(path): """ - Lists all direct subdirectories of a given prefix in a GCS bucket. + Lists all direct subdirectories of a given location in a GCS bucket. Parameters ---------- - bucket_name : str - Name of bucket containing prefix. - prefix : str - Path to location within bucket to be searched. + path : str + Path to location in a GCS bucket. Returns ------- List[str] Direct subdirectories. """ + bucket_name, prefix = parse_cloud_path(path) + prefix = prefix.rstrip("/") + "/" + # Load blobs blobs = storage.Client().list_blobs( bucket_name, prefix=prefix, delimiter="/" @@ -641,7 +679,7 @@ def upload_file_to_s3(src_path, bucket_name, dst_path): s3.upload_file(src_path, bucket_name, dst_path) -# --- Dictionary utils --- +# --- Dictionary Utils --- def find_best(my_dict, maximize=True): """ Finds the key associated with the largest integer or longest list. From f759b0b4393f7b71c71cbb93b6306967f0582987 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Wed, 3 Jun 2026 20:47:10 +0000 Subject: [PATCH 06/11] refactor: configs --- src/neuron_proofreader/config.py | 211 ---- src/neuron_proofreader/configs.py | 124 ++ .../machine_learning/image_dataloader.py | 10 +- ...datamodules_v2.py => merge_datamodules.py} | 52 +- .../merge_proofreading/merge_datasets.py | 1107 ----------------- src/neuron_proofreader/utils/util.py | 4 +- 6 files changed, 160 insertions(+), 1348 deletions(-) delete mode 100644 src/neuron_proofreader/config.py create mode 100644 src/neuron_proofreader/configs.py rename src/neuron_proofreader/merge_proofreading/{merge_datamodules_v2.py => merge_datamodules.py} (94%) delete mode 100644 src/neuron_proofreader/merge_proofreading/merge_datasets.py diff --git a/src/neuron_proofreader/config.py b/src/neuron_proofreader/config.py deleted file mode 100644 index eb8932b2..00000000 --- a/src/neuron_proofreader/config.py +++ /dev/null @@ -1,211 +0,0 @@ -""" -Created on Frid Sept 15 16:00:00 2024 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -This module defines a set of configuration classes used for setting storing -parameters used in neuron proofreading pipelines. - -""" - -from dataclasses import dataclass -from typing import Tuple - -import os - -from neuron_proofreader.utils import util - - -@dataclass -class GraphConfig: - pass - - -@dataclass -class ImageConfig: - """ - Configuration class for image processing parameters. - - Attributes - ---------- - brightness_clip : int - Intensity value that voxel brightness is clipped to. - percentiles : Tuple[float], optional - Percentiles used to normalize patches. - patch_shape : Tuple[int] - Shape of patch to be read from image. - transform : bool - Indication of whether to use image augmentation. - """ - brightness_clip: int = 400 - percentiles: tuple = (1, 99.5) - patch_shape: tuple = (128, 128, 128) - transform: bool = False - - -@dataclass -class MLConfig: - """ - Configuration class for machine learning model parameters. - - Attributes - ---------- - batch_size : int - The number of samples processed in one batch during training or - inference. Default is 64. - brightness_clip : int - Maximum brightness value that image intensities are clipped to. - Default is 400. - device : str - Device to load model onto. Default is "cuda". - model_name : str - Name of model used to perform inference. Default is None. - patch_shape : Tuple[int] - Shape of image patch expected by vision model. Default is (96, 96, 96). - transform : bool - Indication of whether to apply data augmentation to image patches. - Default is False. - """ - batch_size: int = 64 - brightness_clip: int = 400 - device: str = "cuda" - model_name: str = None - patch_shape: tuple = (96, 96, 96) - transform: bool = False - - def to_dict(self): - """ - Converts configuration attributes to a dictionary. - - Returns - ------- - dict - Dictionary containing configuration attributes. - """ - attributes = dict() - for k, v in vars(self).items(): - if isinstance(v, tuple): - attributes[k] = list(v) - else: - attributes[k] = v - return attributes - - def save(self, path): - """ - Saves configuration attributes to a JSON file. - """ - util.write_json(path, self.to_dict()) - - -@dataclass -class ProposalsConfig: - pass - - -@dataclass -class ProposalGraphConfig: - """ - Represents configuration settings related to graph properties and - proposals generated. - - Attributes - ---------- - allow_nonleaf_proposals : bool, optional - Indication of whether to generate proposals between leaf and nodes - with degree 2. Default is False. - anisotropy : Tuple[float] - Scaling factors used to transform physical to image coordinates - Default is (1.0, 1.0, 1.0). - max_proposals_per_leaf : int - Maximum number of proposals generated at leaf nodes. Default is 3. - min_cable_length : float - Minimum path length (in microns) of SWC files loaded into a graph - object. Default is 40. - node_spacing : float - Physcial spacing (in microns) between nodes. Default is 1. - prune_depth : int - Branches in graph less than "prune_depth" microns are pruned. Default - is 24. - remove_doubles : bool - Indication of whether to remove fragments that are likely a double of - another fragment. Default is True. - remove_high_risk_merges : bool - Indication of whether to remove high risk merge sites (i.e. close - branching points). Default is False. - trim_endpoints_bool : bool - Indication of whether trim endpoints of branches with exactly one - proposal. Default is True. - verbose : bool - Indication of whether to display a progress bar. Default is True. - """ - - allow_nonleaf_proposals: bool = False - anisotropy: Tuple[float, float, float] = (1.0, 1.0, 1.0) - max_proposals_per_leaf: int = 3 - min_cable_length: float = 40.0 - node_spacing: float = 1.0 - proposals_per_leaf: int = 3 - prune_depth: float = 24.0 - remove_doubles: bool = True - remove_high_risk_merges: bool = False - trim_endpoints_bool: bool = True - verbose: bool = True - - def to_dict(self): - """ - Converts configuration attributes to a dictionary. - - Returns - ------- - dict - Dictionary containing configuration attributes. - """ - attributes = dict() - for k, v in vars(self).items(): - if isinstance(v, tuple): - attributes[k] = list(v) - else: - attributes[k] = v - return attributes - - def save(self, path): - """ - Saves configuration attributes to a JSON file. - """ - util.write_json(path, self.to_dict()) - - -@dataclass -class Config: - """ - A configuration class for managing and storing settings related to graph - and machine learning models. - """ - - def __init__(self, graph_config, ml_config): - """ - Initializes a Config object which is used to manage settings used to - run the proofreading pipeline. - - Parameters - ---------- - graph_config : GraphConfig - Instance of the "GraphConfig" class that contains configuration - parameters for graph and proposal operations. - ml_config : MLConfig - An instance of the "MLConfig" class that includes configuration - parameters for machine learning models. - """ - self.graph = graph_config - self.ml = ml_config - - def save(self, dir_path): - """ - Saves configuration attributes to a JSON file. - - dir_path : str - Path to directory to save JSON file. - """ - self.graph.save(os.path.join(dir_path, "metadata_graph.json")) - self.ml.save(os.path.join(dir_path, "metadata_ml.json")) diff --git a/src/neuron_proofreader/configs.py b/src/neuron_proofreader/configs.py new file mode 100644 index 00000000..dc6ea5c6 --- /dev/null +++ b/src/neuron_proofreader/configs.py @@ -0,0 +1,124 @@ +""" +Created on Frid Sept 15 16:00:00 2024 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +Connfiguration classes used for setting storing parameters used in neuron +proofreading pipelines. + +""" + +from abc import ABC +from dataclasses import asdict, dataclass +from typing import Tuple + +import os + +from neuron_proofreader.utils import util + + +class Config(ABC): + + def to_dict(self): + """ + Converts configuration attributes to a dictionary. + + Returns + ------- + attrs : dict + Dictionary containing configuration attributes. + """ + attrs = asdict(self) + for k, v in attrs.items(): + if isinstance(v, tuple): + attrs[k] = list(v) + return attrs + + def save(self, output_dir): + """ + Saves configuration attributes to a JSON file. + + dir_path : str + Path to directory to save JSON file. + """ + path = os.path.join(output_dir, f"{self.name}.json") + util.write_json(path, self.to_dict()) + + +@dataclass +class GraphConfig: + """ + Configuration class for skeleton graph parameters. + + Attributes + ---------- + anisotropy : Tuple[float] + Scaling factors used to transform physical to image coordinates. + min_cable_length : float + Minimum path length (in microns) of SWC files loaded into graph. + node_spacing : float + Physcial spacing (in microns) between nodes. + prune_depth : int + ... + remove_doubles : bool + Indication of whether to remove fragments that are likely a double of + another. + verbose : bool + Indication of whether to display a progress bar. + """ + + anisotropy: Tuple[float, float, float] = (1.0, 1.0, 1.0) + min_cable_length: float = 0.0 + name: str = "graph_config" + node_spacing: float = 1.0 + prune_depth: float = 20.0 + remove_doubles: bool = True + use_anisotropy: bool = True + verbose: bool = False + + +@dataclass +class ImageConfig: + """ + Configuration class for image processing parameters. + + Attributes + ---------- + brightness_clip : int + Intensity value that voxel brightness is clipped to. + percentiles : Tuple[float], optional + Percentiles used to normalize patches. + patch_shape : Tuple[int] + Shape of patch to be read from image. + transform : bool + Indication of whether to use image augmentation. + """ + + brightness_clip: int = 400 + name: str = "image_config" + percentiles: Tuple[float, float] = (1, 99.5) + patch_shape: Tuple[int, int, int] = (128, 128, 128) + transform: bool = False + + +@dataclass +class ProposalsConfig: + """ + Configuration class for skeleton graph parameters. + + Attributes + ---------- + allow_nonleaf_proposals : bool + Indication of whether to generate proposals between leaf and nodes + with degree 2. + proposals_per_leaf : int + Maximum number of proposals generated at leaf nodes. + trim_endpoints_bool : bool + Indication of whether trim endpoints of isolated leaf-to-leaf + proposals. + """ + + allow_nonleaf_proposals: bool = False + proposals_per_leaf: int = 3 + trim_endpoints_bool: bool = True diff --git a/src/neuron_proofreader/machine_learning/image_dataloader.py b/src/neuron_proofreader/machine_learning/image_dataloader.py index ae6deba0..a9035b0b 100644 --- a/src/neuron_proofreader/machine_learning/image_dataloader.py +++ b/src/neuron_proofreader/machine_learning/image_dataloader.py @@ -13,8 +13,9 @@ import numpy as np import tensorstore as ts +from neuron_proofreader.configs import ImageConfig from neuron_proofreader.machine_learning.image_augmentation import ( - ImageTransforms + ImageTransforms, ) from neuron_proofreader.utils import geometry_util, img_util, util @@ -108,14 +109,15 @@ def __init__(self, graph, img_config, img_path): ---------- graph : SkeletonGraph Graph used to compute patch voxel coordinates. + img_config : ImageConfig or None + Config object with image processing parameters. img_path : str Path to whole-brain image. - """ - self.config = img_config + self.config = img_config or ImageConfig() self.graph = graph self.img = TensorStoreImage(img_path) - self.transform = ImageTransforms() if img_config.transform else None + self.transform = ImageTransforms() if self.config.transform else None # --- Abstract Interface --- @abstractmethod diff --git a/src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py similarity index 94% rename from src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py rename to src/neuron_proofreader/merge_proofreading/merge_datamodules.py index 21776d86..d34be19c 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datamodules_v2.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py @@ -20,7 +20,7 @@ the correct brain, exposes split() for train/val partitioning, and is the object handed to MergeSiteDataLoader. -MergeSiteDataLoader +ThreadedDataLoader Custom DataLoader that uses multithreading to fetch image patches from cloud storage and assemble batches. """ @@ -72,7 +72,7 @@ class BrainDataset: Radius (in microns) used when extracting rooted subgraphs. """ - giant_component_cable_length = 10**4 + giant_component_cable_length = 3 * 10**4 def __init__( self, @@ -80,20 +80,21 @@ def __init__( img_path, sites_prefix, swcs_path, - anisotropy=(1.0, 1.0, 1.0), + graph_config=None, img_config=None, - node_spacing=5, random_nonmerge_site_prob=0.5, + rebalance_classes=True, subgraph_depth=100, ): # Instance attributes self.brain_id = brain_id + self.rebalance_classes = rebalance_classes self.ignore_fragments = set() self.random_nonmerge_site_prob = random_nonmerge_site_prob self.subgraph_depth = subgraph_depth # Core data structures - self.graph = self.load_fragments(anisotropy, node_spacing, swcs_path) + self.graph = self.load_fragments(graph_config, swcs_path) self.merge_sites = self.load_sites( os.path.join(sites_prefix, "merge_sites") ) @@ -107,12 +108,13 @@ def __init__( self.set_merge_site_info() self.set_valid_branching_nodes() - def load_fragments(self, anisotropy, node_spacing, swcs_path): + def load_fragments(self, config, swcs_path): graph = SkeletonGraph( - anisotropy=anisotropy, - node_spacing=node_spacing, - use_anisotropy=False, - verbose=True, + anisotropy=config.anisotropy, + min_cable_length=config.min_cable_length, + node_spacing=config.node_spacing, + use_anisotropy=config.use_anisotropy, + verbose=config.verbose, ) graph.load(swcs_path) return graph @@ -208,7 +210,7 @@ def add_nonmerge_sites(self, num_sites): site = { "node": node, "xyz": self.graph.node_xyz[node], - "filename": "random" + "filename": "random", } new_sites.append(site) @@ -241,15 +243,16 @@ def _has_nearby_branching(self, root, max_depth=60): def _list_indices(self): # Set idxs pos_idxs = np.arange(len(self.merge_sites)) - neg_idxs = np.arange(len(self.nonmerge_sites)) + neg_idxs = -np.arange(len(self.nonmerge_sites)) # Check for class imbalance - if len(neg_idxs) < len(pos_idxs): - neg_idxs = -pos_idxs - else: - neg_idxs = -np.random.choice( - neg_idxs, size=len(pos_idxs), replace=False - ) + if self.rebalance_classes: + if len(neg_idxs) < len(pos_idxs): + neg_idxs = -pos_idxs + else: + neg_idxs = np.random.choice( + neg_idxs, size=len(pos_idxs), replace=False + ) return np.concatenate((pos_idxs, neg_idxs)) def __len__(self): @@ -331,6 +334,7 @@ def get_idxs(self): Returns ------- numpy.ndarray + Indices over the full index table. """ return np.arange(len(self._index_table)) @@ -389,7 +393,9 @@ def __init__( self.modality = modality self.use_shuffle = use_shuffle self.prefetch_batches = prefetch_batches - self.patches_shape = (2,) + dataset.brain_datasets[0].patch_loader.patch_shape + self.patches_shape = (2,) + dataset.brain_datasets[ + 0 + ].patch_loader.patch_shape # Set batch loader if self.is_multimodal and self.modality == "graph": @@ -572,10 +578,10 @@ def create_dataset_collection( img_prefixes_path, sites_root_path, swcs_root_path, - anisotropy=(1.0, 1.0, 1.0), + graph_config=None, img_config=None, - node_spacing=5, random_nonmerge_site_prob=0.5, + rebalance_classes=False, subgraph_depth=100, ): # Load image prefixes @@ -599,11 +605,11 @@ def create_dataset_collection( img_path, sites_path, swcs_path, - anisotropy=anisotropy, + graph_config=graph_config, img_config=img_config, subgraph_depth=subgraph_depth, - node_spacing=node_spacing, random_nonmerge_site_prob=random_nonmerge_site_prob, + rebalance_classes=rebalance_classes, ) datasets.append(dataset) return BrainDatasetCollection(datasets) diff --git a/src/neuron_proofreader/merge_proofreading/merge_datasets.py b/src/neuron_proofreader/merge_proofreading/merge_datasets.py deleted file mode 100644 index 63a838e0..00000000 --- a/src/neuron_proofreader/merge_proofreading/merge_datasets.py +++ /dev/null @@ -1,1107 +0,0 @@ -""" -Created on Wed July 2 11:00:00 2025 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -Dataset and dataloader utilities for processing merge site data to train model -to detect merge errors. - -""" - -from concurrent.futures import as_completed, ThreadPoolExecutor -from scipy.spatial import KDTree -from torch.utils.data import Dataset, DataLoader - -import copy -import networkx as nx -import numpy as np -import os -import pandas as pd -import random -import torch - -from neuron_proofreader.machine_learning.augmentation import ImageTransforms -from neuron_proofreader.machine_learning.geometric_gnn_models import ( - subgraph_to_data, -) -from neuron_proofreader.machine_learning.point_cloud_models import ( - subgraph_to_point_cloud, -) -from neuron_proofreader.merge_proofreading.merge_dataloading import ( - get_brain_merge_sites, -) -from neuron_proofreader.skeleton_graph import SkeletonGraph -from neuron_proofreader.utils import ( - geometry_util, - img_util, - ml_util, - util, -) - - -# --- Datasets --- -class MergeSiteDataset(Dataset): - """ - Dataset class for loading and processing merge site data. The core data - structure is the attribute "merge_sites_df" which contains metadata about - each merge site. - - Attributes - ---------- - anisotropy : Tuple[float], optional - Image to physical coordinates scaling factors to account for the - anisotropy of the microscope. - subgraph_radius : int, optional - Radius (in microns) around merge sites used to extract rooted - subgraphs. - gt_graphs : Dict[str, SkeletonGraph] - Dictionary that maps brain IDs to a graph containing ground truth - skeletons. - graphs : Dict[str, SkeletonGraph] - Dictionary that maps brain IDs to graphs containing fragments that - have merge mistakes. - img_readers : Dict[str, ImageReader] - Image readers used to read raw images from cloud bucket. - merge_sites_df : pandas.DataFrame - DataFrame containing merge sites, must contain the columns: "brain_id" - "segmentation_id", "segment_id", and "xyz". - node_spacing : int, optional - Spacing (in microns) between neighboring nodes in graphs. - patch_shape : Tuple[int], optional - Shape of the 3D image patches to extract. - """ - - random_negative_example_prob = 0.8 - - def __init__( - self, - merge_sites_df, - anisotropy=(1.0, 1.0, 1.0), - brightness_clip=400, - subgraph_radius=100, - node_spacing=5, - patch_shape=(128, 128, 128), - use_new_mask=False, - ): - """ - Instantiates a MergeSiteDataset object. - - Parameters - ---------- - merge_sites_df : pandas.DataFrame - DataFrame containing merge sites, must contain the columns: - "brain_id", "segmentation_id", "segment_id", and "xyz". - anisotropy : Tuple[float], optional - Image to physical coordinates scaling factors to account for the - anisotropy of the microscope. Default is (1.0, 1.0, 1.0). - brightness_clip : int, optional - ... - subgraph_radius : int, optional - Radius (in microns) around merge sites used to extract rooted - subgraph. Default is 100μm. - node_spacing : int, optional - Spacing between nodes in the graph. Default is 5μm. - patch_shape : Tuple[int], optional - Shape of the 3D patches to extract. Default is (128, 128, 128). - """ - # Instance attributes - self.anisotropy = anisotropy - self.brightness_clip = brightness_clip - self.node_spacing = node_spacing - self.merge_sites_df = merge_sites_df - self.patch_shape = patch_shape - self.subgraph_radius = subgraph_radius - self.use_new_mask = use_new_mask - - # Data structures - self.graphs = dict() - self.gt_graphs = dict() - self.img_readers = dict() - self.segmentation_readers = dict() - self.merge_site_kdtrees = dict() - - # --- Load Data --- - def load_fragment_graphs(self, brain_id, swc_pointer, use_anisotropy=True): - """ - Loads fragments containing merge mistakes for a whole-brain dataset, - then stores them in the "graphs" attribute. - - Parameters - ---------- - brain_id : str - Unique identifier for a whole-brain dataset. - swc_pointer : str - Pointer to SWC files to be loaded into a graph. - """ - # Load graphs - graph = SkeletonGraph( - anisotropy=self.anisotropy, - node_spacing=self.node_spacing, - use_anisotropy=use_anisotropy, - ) - graph.load(swc_pointer) - - # Remove groundtruth skeletons - for swc_id in graph.swc_ids(): - if swc_id.lower().startswith("n"): - component_id = graph.component_id_from_swc_id(swc_id) - nodes = graph.nodes_with_component_id(component_id) - graph.remove_nodes(nodes, relabel_nodes=False) - - # Remove fragments excluded from merge sites - brain_idxs = self.merge_sites_df["brain_id"] == brain_id - merge_sites = self.merge_sites_df[brain_idxs] - segment_ids = set(merge_sites["segment_id"].unique()) - for nodes in map(list, list(nx.connected_components(graph))): - node = util.sample_once(nodes) - segment_id = graph.node_segment_id(node) - if segment_id not in segment_ids: - graph.remove_nodes(nodes, relabel_nodes=False) - graph.relabel_nodes() - - # Build merge site kdtrees - pts = get_brain_merge_sites(self.merge_sites_df, brain_id) - self.merge_site_kdtrees[brain_id] = KDTree(pts) - - # Post process fragments - self.clip_fragments_to_groundtruth(brain_id, graph) - self.graphs[brain_id] = graph - - def load_gt_graphs(self, brain_id, swc_pointer): - """ - Loads ground truth skeletons for a whole-brain dataset, then stores - them in the "gt_graphs" attribute. - - Parameters - ---------- - brain_id : str - Unique identifier for a whole-brain dataset. - swc_pointer : str - Pointer to SWC files to be loaded into graph. - """ - self.gt_graphs[brain_id] = SkeletonGraph( - anisotropy=self.anisotropy, - node_spacing=self.node_spacing, - ) - self.gt_graphs[brain_id].load(swc_pointer) - self.gt_graphs[brain_id].set_kdtree() - - def load_images(self, brain_id, img_path, segmentation_path): - """ - Loads image reader for a whole-brain dataset, then stores it in the - "img_readers" attribute. - - Parameters - ---------- - brain_id : str - Unique identifier for a whole-brain dataset. - img_path : str - Path to whole-brain image. - segmentation_path : str - Path to segmentation of whole-brain image. - """ - self.img_readers[brain_id] = img_util.TensorStoreReader(img_path) - self.segmentation_readers[brain_id] = img_util.TensorStoreReader( - segmentation_path - ) - - # --- Create Subclass Dataset --- - def subset(self, cls, idxs): - """ - Creates a derived dataset keeping only specified indices. - - Parameters - ---------- - cls : class - Class of the new dataset. - idxs : List[int] - Indices of merge sites to keep. - - Returns - ------- - new_dataset : cls - New dataset instance containing only the specified subset. - """ - new_dataset = cls.__new__(cls) - new_dataset.__dict__ = copy.deepcopy(self.__dict__) - new_dataset.remove_nonindexed_fragments(idxs) - new_dataset.remove_isolated_sites() - return new_dataset - - def remove_isolated_sites(self): - """ - Removes merge sites whose closest fragment is greater than a specified - distance. - """ - # Find non-isolated sites - idxs = list() - for i in range(len(self.merge_sites_df)): - brain_id = self.merge_sites_df["brain_id"][i] - xyz = self.merge_sites_df["xyz"][i] - if brain_id in self.graphs: - d, _ = self.graphs[brain_id].kdtree.query(xyz) - if d < 10: - idxs.append(i) - - # Drop isolated sites - self.merge_sites_df = self.merge_sites_df.iloc[idxs] - self.merge_sites_df = self.merge_sites_df.reset_index(drop=True) - - def remove_nonindexed_fragments(self, idxs): - """ - Removes fragments that do not correspond to the given site indices. - - Parameters - ---------- - idxs : List[int] - Indices of merge sites to keep. Fragments associated with all - other sites are removed. - """ - # Remove other fragments - visited = set() - for i in [i for i in self.merge_sites_df.index if i not in idxs]: - # Extract site info - brain_id = self.merge_sites_df["brain_id"][i] - segment_id = self.merge_sites_df["segment_id"][i] - pair = (brain_id, segment_id) - - # Find fragment containing site - if pair not in visited: - nodes = self.graphs[brain_id].nodes_with_segment_id(segment_id) - self.graphs[brain_id].remove_nodes(nodes, False) - visited.add(pair) - - self.remove_empty_graphs() - - # Relabel nodes - for brain_id in self.graphs: - self.graphs[brain_id].relabel_nodes() - - # Update merge sites df - self.merge_sites_df = self.merge_sites_df.iloc[idxs] - self.merge_sites_df = self.merge_sites_df.reset_index(drop=True) - - def remove_empty_graphs(self): - """ - Removes graphs without any nodes. - """ - for brain_id in list(self.graphs.keys()): - if len(self.graphs[brain_id].nodes) == 0: - del self.graphs[brain_id] - - # --- Getters --- - def __getitem__(self, idx): - """ - Gets the example corresponding to the given index, which consists of - an image patch, label mask, and rooted subgraph. - - Parameters - ---------- - idx : int - Index of example to retrieve. Positive indices correspond to merge - sites, while non-positive indices correspond to non-merge sites. - - Returns - ------- - patches : numpy.ndarray - Array containing the image patch and segment mask with shape - (2, D, H, W). - subgraph : networkx.Graph - Rooted subgraph centered at the site node. - label : int - 1 if the example is positive and 0 otherwise. - """ - # Get example - brain_id, subgraph, label = self.get_site(idx) - voxel = subgraph.node_voxel(0) - - # Extract subgraph and image patches centered at site - img_patch = self.get_img_patch(brain_id, voxel) - segment_mask = self.get_segment_mask(brain_id, voxel, subgraph) - - # Stack image channels - try: - patches = np.stack([img_patch, segment_mask], axis=0) - except ValueError: - img_patch = img_util.pad_to_shape(img_patch, self.patch_shape) - patches = np.stack([img_patch, segment_mask], axis=0) - return patches, subgraph, label - - def sample_brain_id(self): - """ - Samples a brain ID. - - Returns - ------- - brain_id : str - Unique identifier of a whole-brain dataset. - """ - while True: - brain_id = util.sample_once(list(self.graphs.keys())) - if len(self.graphs[brain_id].nodes) > 0: - return brain_id - - def get_indexed_negative_site(self, idx): - """ - Gets the negative example corresponding to the given index. - - Parameters - ---------- - idx : int - Index of the site in "sites_df". - - Returns - ------- - brain_id : str - Unique identifier for the whole-brain dataset containing the site. - node : int - Node ID of the site. - label : int - Label of example. - """ - # Get site info - brain_id = self.merge_sites_df["brain_id"].iloc[idx] - xyz = self.merge_sites_df["xyz"].iloc[idx] - node = self.gt_graphs[brain_id].kdtree.query(xyz)[1] - - # Extract rooted subgraph - subgraph = self.gt_graphs[brain_id].rooted_subgraph( - node, self.subgraph_radius - ) - return brain_id, subgraph, 0 - - def get_indexed_positive_site(self, idx): - """ - Gets the positive example corresponding to the given index. - - Parameters - ---------- - idx : int - Index of the site in "sites_df". - - Returns - ------- - brain_id : str - Unique identifier for the whole-brain dataset containing the site. - node : int - Node ID of the site. - label : int - Label of example. - """ - # Get site info - brain_id = self.merge_sites_df["brain_id"].iloc[idx] - xyz = self.merge_sites_df["xyz"].iloc[idx] - node = self.graphs[brain_id].kdtree.query(xyz)[1] - - # Extract rooted subgraph - subgraph = self.graphs[brain_id].rooted_subgraph( - node, self.subgraph_radius - ) - return brain_id, subgraph, 1 - - def get_random_negative_site(self): - """ - Gets a random non-merge site from a fragment graph. - - Returns - ------- - brain_id : str - Unique identifier of the whole-brain dataset containing the site. - node : int - Node ID of the site. - label : int - Label of example. - """ - brain_id = self.sample_brain_id() - outcome = random.random() - cnt = 0 - while True: - # Sample node - cnt += 1 - if outcome < 0.4: - # Any node - node = util.sample_once(self.graphs[brain_id].nodes) - elif outcome < 0.8: - # Branching node - branching_nodes = self.graphs[brain_id].branching_nodes() - if len(branching_nodes) > 0: - node = util.sample_once(branching_nodes) - else: - outcome = 0 - continue - else: - # Branching node from GT - branching_nodes = self.gt_graphs[brain_id].branching_nodes() - node = util.sample_once(branching_nodes) - subgraph = self.gt_graphs[brain_id].rooted_subgraph( - node, self.subgraph_radius - ) - return brain_id, subgraph, 0 - - # Extract rooted subgraph - subgraph = self.graphs[brain_id].rooted_subgraph( - node, self.subgraph_radius - ) - - # Check branching - if self.graphs[brain_id].degree(node) > 2: - is_high_degree = self.graphs[brain_id].degree(node) > 3 - is_too_branchy = self.check_nearby_branching(brain_id, node) - if is_high_degree or is_too_branchy: - continue - - # Check if node is close to merge site - if not self.is_nearby_merge_site(brain_id, node): - return brain_id, subgraph, 0 - - # Check number of tries - if cnt > 20: - outcome = 1 - - def get_img_patch(self, brain_id, center): - """ - Extracts and normalizes a 3D image patch from the specified whole- - brain dataset. - - Parameters - ---------- - brain_id : str - Unique identifier of the whole-brain dataset to read from. - center : numpy.ndarray - Voxel coordinates of the patch center. - - Returns - ------- - img_patch : numpy.ndarray - Extracted image patch, which has been normalized and clipped to a - maximum value of "self.brightness_clip". - """ - img_patch = self.img_readers[brain_id].read(center, self.patch_shape) - img_patch = np.minimum(img_patch, self.brightness_clip) - return img_util.normalize(img_patch) - - def get_segment_mask(self, brain_id, center, subgraph): - """ - Generates a binary mask for a given subgraph within a patch. - - Parameters - ---------- - subgraph : SkeletonGraph - Rooted subgraph centered at the site node. - - Returns - ------- - segment_mask : numpy.ndarray - Binary mask for a given subgraph within a patch. - """ - # Read segmentation - if self.use_new_mask: - segment_mask = self.segmentation_readers[brain_id].read( - center, self.patch_shape - ) - segment_mask = img_util.remove_small_segments(segment_mask, 1000) - segment_mask = 0.5 * (segment_mask > 0).astype(float) - else: - segment_mask = np.zeros(self.patch_shape) - - # Annotate fragment - center = subgraph.node_voxel(0) - offset = img_util.get_offset(center, self.patch_shape) - for node1, node2 in subgraph.edges: - # Get local voxel coordinates - voxel1 = subgraph.node_local_voxel(node1, offset) - voxel2 = subgraph.node_local_voxel(node2, offset) - - # Populate mask - voxels = geometry_util.make_digital_line(voxel1, voxel2) - img_util.annotate_voxels(segment_mask, voxels) - return segment_mask - - # --- Helpers --- - def __len__(self): - """ - Returns the number of positive and negative examples of merge sites. - - Returns - ------- - int - Number of positive examples of merge sites. - """ - return 2 * len(self.merge_sites_df) - - def check_nearby_branching( - self, brain_id, root, max_depth=60, use_gt=False - ): - """ - Checks if there is a branching node within a specified depth from the - given node. - - Parameters - ---------- - brain_id : str - Unique identifier for graph to be searched. - root : int - Node ID. - max_depth : float, optional - Maximum depth (in microns) of search. Default is 20μm. - use_gt : bool - Indication of whether to check groundtruth graph. Default is - False. - - Returns - ------- - bool - Indication of whether there is a nearby branching node. - """ - graph = self.gt_graphs[brain_id] if use_gt else self.graphs[brain_id] - queue = [(root, 0)] - visited = set([root]) - while queue: - # Visit node - i, d_i = queue.pop() - if graph.degree[i] > 2 and d_i > 0: - return True - - # Update queue - for j in graph.neighbors(i): - d_j = d_i + graph.dist(i, j) - if j not in visited and d_j < max_depth: - queue.append((j, d_j)) - visited.add(j) - return False - - def clip_fragments_to_groundtruth(self, brain_id, graph): - """ - Removes any node from the given fragment that is more than 100μm from - the ground truth graph. - - Parameters - ---------- - brain_id : str - Unique identifier for a whole-brain dataset. - graph : SkeletonGraph - Fragment graph to be clipped. - """ - assert brain_id in self.gt_graphs, "Must load GT before fragments!" - d_gt, _ = self.gt_graphs[brain_id].kdtree.query(graph.node_xyz) - nodes = np.where(d_gt > 100)[0] - graph.remove_nodes(nodes) - - def count_fragments(self): - """ - Counts the number of fragments in the dataset. - - Returns - ------- - cnt : int - Number of fragments in the dataset. - """ - cnt = 0 - for graph in self.graphs.values(): - cnt += nx.number_connected_components(graph) - return cnt - - def is_nearby_merge_site(self, brain_id, node): - """ - Checks whether to the given node is close to a merge site. - - Parameters - ---------- - brain_id : str - Unique identifier for graph to be searched. - node : int - Node ID to check if it's close to a merge site. - """ - xyz = self.graphs[brain_id].node_xyz[node] - dist, _ = self.merge_site_kdtrees[brain_id].query(xyz) - return dist < 100 - - def sample_node_nearby_soma(self, brain_id): - subgraph = self.gt_graphs[brain_id].get_rooted_subgraph(0, 600) - gt_node = util.sample_once(subgraph.nodes) - gt_xyz = self.gt_graphs[brain_id].node_xyz[gt_node] - d, node = self.graphs[brain_id].kdtree.query(gt_xyz) - return node - - -class MergeSiteTrainDataset(MergeSiteDataset): - """ - A class for storing and retrieving training examples. - """ - - def __init__(self, base_dataset=None, idxs=None, negative_bias=0): - """ - Instantiates a MergeSiteTrainDataset object. - - Parameters - ---------- - base_dataset : MergeSiteDataset, optional - Dataset to be instantiated as a train dataset. - idxs : List[int], optional - Indices of examples to be kept in train dataset. - negative_bias : float, optional - Specifies percentage of additional negative examples to add. - """ - # Create sub-dataset - subset_dataset = base_dataset.subset(self.__class__, idxs) - self.__dict__.update(subset_dataset.__dict__) - - # Instance attributes - self.negative_bias = negative_bias - self.transform = ImageTransforms() - - # --- Getters --- - def __getitem__(self, idx): - """ - Gets the example specified by the given index. - - Parameters - ---------- - idx : int - Index of example. - - Returns - ------- - patches : numpy.ndarray - Array of stacked channels containing the image patch and label - mask with shape (2, D, H, W). - subgraph : SkeletonGraph - Rooted subgraph centered at the site node. - label : int - 1 if the example is positive and 0 otherwise. - """ - patches, subgraph, label = super().__getitem__(idx) - self.transform(patches) - return patches, subgraph, label - - def get_site(self, idx): - """ - Retrieves a merge or nonmerge site specified by the given index. - - Parameters - ---------- - idx : int - Index of site to retrieve. Positive indices correspond to merge - sites, non-positive indices correspond to non-merge sites. - - Returns - ------- - brain_id : str - Unique identifier of the brain containing the site. - node : int - Node ID of the site. - label : int - 1 if the example is positive and 0 otherwise. - """ - if idx > 0: - return self.get_indexed_positive_site(idx) - elif np.random.random() < self.random_negative_example_prob: - return self.get_random_negative_site() - elif abs(idx) < len(self.merge_sites_df): - return self.get_indexed_negative_site(abs(idx)) - else: - return self.get_random_negative_site() - - # --- Helpers --- - def get_idxs(self): - """ - Gets example indices to iterate over. - - Returns - ------- - numpy.ndarray - Example indices to iterate over. - """ - n_pos_examples = len(self.merge_sites_df) - n_negative_examples = int(n_pos_examples * (1 + self.negative_bias)) - return np.arange(-n_negative_examples + 1, n_pos_examples) - - -class MergeSiteValDataset(MergeSiteDataset): - """ - A class for storing and retrieving validation examples. - """ - - def __init__(self, base_dataset=None, idxs=None): - """ - Instantiates a MergeSiteValDataset object. - - Parameters - ---------- - base_dataset : MergeSiteDataset, optional - Dataset to be instantiated as a validation dataset. - idxs : List[int], optional - Indices of examples to be kept in validation dataset. - """ - # Create sub-dataset - subset_dataset = base_dataset.subset(self.__class__, idxs) - self.__dict__.update(subset_dataset.__dict__) - - # Instance attributes - self.examples = self.generate_examples() - self.examples_summary = self.set_examples_summary() - - def generate_examples(self): - # Generate negative examples - negative_examples = self.generate_negative_examples() - - # Generate positive examples - positive_examples = list() - for i in range(len(self.merge_sites_df)): - brain_id, subgraph, _ = self.get_indexed_positive_site(i) - positive_examples.append( - { - "brain_id": brain_id, - "subgraph": subgraph, - "xyz": subgraph.node_xyz[0], - "label": 1, - } - ) - return positive_examples + negative_examples - - def generate_negative_examples(self): - """ - Generates examples of non-merge sites by sampling points on fragments - that are sufficiently far from a merge site. - - Returns - ------- - negative_examples : List[dict] - List of negative examples collected across all graphs. - """ - - # Subroutines - def add_examples(): - """ - Adds the given example to the set of validation examples. - """ - for node in random.sample(nodes, n_examples): - # Check if close to merge site - if not self.is_nearby_merge_site(brain_id, node): - subgraph = graph.rooted_subgraph( - node, self.subgraph_radius - ) - negative_examples.append( - { - "brain_id": brain_id, - "subgraph": subgraph, - "xyz": subgraph.node_xyz[0], - "label": 0, - } - ) - - # Add branching nodes - negative_examples = list() - for brain_id, graph in self.graphs.items(): - # Filter branching nodes near other branching nodes - nodes = list() - for i in graph.branching_nodes(): - is_branchy = self.check_nearby_branching(brain_id, i) - if not is_branchy and graph.degree[i] == 3: - nodes.append(i) - - # Add nodes to examples - n_examples = min(len(nodes), 100) - add_examples() - - # Add non-branching points - for brain_id, graph in self.graphs.items(): - nodes = [i for i in graph.nodes if graph.degree[i] < 3] - n_examples = min(len(nodes), 100) - add_examples() - return negative_examples - - def set_examples_summary(self): - """ - Sets a summary of examples in the validation dataset. - - Returns - ------- - List[dict] - List containing example metadata stored in a dictionary. - """ - summary = list() - for example in self.examples: - summary.append( - { - "brain_id": example["brain_id"], - "xyz": example["xyz"], - "label": example["label"], - } - ) - return summary - - # --- Getters --- - def get_site(self, idx): - """ - Retrieves a merge or nonmerge site specified by the given index. - - Parameters - ---------- - idx : int - Index of site to retrieve. Positive indices correspond to merge - sites, non-positive indices correspond to non-merge sites. - - Returns - ------- - brain_id : str - Unique identifier of the brain containing the site. - node : int - Node ID of the site. - label : int - 1 if the example is positive and 0 otherwise. - """ - brain_id = self.examples[idx]["brain_id"] - subgraph = self.examples[idx]["subgraph"] - label = self.examples[idx]["label"] - return brain_id, subgraph, label - - def get_indexed_negative_site(self, idx): - """ - Gets the negative example corresponding to the given index. - - Parameters - ---------- - idx : int - Index of example. - - Returns - ------- - brain_id : str - Unique identifier for the whole-brain dataset containing the site. - node : int - Node ID of the site. - label : int - Label of example. - """ - brain_id = self.negative_examples[idx]["brain_id"] - subgraph = self.negative_examples[idx]["subgraph"] - return brain_id, subgraph, 0 - - # --- Helpers --- - def __len__(self): - """ - Gets the number of examples in the dataset. - """ - return len(self.examples) - - def get_idxs(self): - """ - Gets example indices to iterate over. - - Returns - ------- - numpy.ndarray - Example indices to iterate over. - """ - return np.arange(len(self.examples)) - - def save_summary(self, output_dir): - """ - Saves the example summary as a CSV file. - - Parameters - ---------- - output_dir : str - Path to directory that summary file is saved to. - """ - df = pd.DataFrame(self.examples_summary) - df.to_csv(os.path.join(output_dir, "val_summary.csv")) - - -# --- DataLoaders --- -class MergeSiteDataLoader(DataLoader): - """ - A custom DataLoader class that uses multithreading to read image patches - from the cloud to form batches. - """ - - def __init__( - self, - dataset, - batch_size=32, - is_multimodal=False, - modality=None, - sampler=None, - use_shuffle=True, - ): - """ - Instantiates a MergeSiteDataLoader object. - - Parameters - ---------- - dataset : MergeSiteDataset - Dataset to be iterated over to train or validate. - batch_size : int, optional - Number of examples in each batch. Default is 32. - is_multimodal : bool, optional - Indication of whether the loaded data is multimodal. Default is - False. - use_shuffle : bool, optional - Indication of whether to shuffle examples. Default is True. - """ - # Call parent class - super().__init__(dataset, batch_size=batch_size, sampler=sampler) - assert modality in [None, "graph", "pointcloud"] - - # Instance attributes - self.is_multimodal = is_multimodal - self.modality = modality - self.patches_shape = (2,) + self.dataset.patch_shape - self.use_shuffle = use_shuffle - - # --- Core Routines --- - def __iter__(self): - """ - Generates batches of examples for training and validation. - - Returns - ------- - iterator - Generates batch of examples used during training and validation. - """ - # Set indices - idxs = self.dataset.get_idxs() - if self.use_shuffle: - random.shuffle(idxs) - - # Iterate over indices - for start in range(0, len(idxs), self.batch_size): - end = min(start + self.batch_size, len(idxs)) - if self.is_multimodal and self.modality == "graph": - yield self._load_image_graph_batch(idxs[start:end]) - elif self.is_multimodal and self.modality == "pointcloud": - yield self._load_image_pc_batch(idxs[start:end]) - else: - yield self._load_image_batch(idxs[start:end]) - - def _load_image_batch(self, batch_idxs): - """ - Loads a batch of samples from the dataset using multithreading. - - Parameters - ---------- - batch_idxs : List[int] - Indices of the dataset items to include in the batch. - - Returns - ------- - patches : torch.Tensor - Image patches for the batch. - targets : torch.Tensor - Target labels corresponding to each patch. - """ - with ThreadPoolExecutor() as executor: - # Assign threads - pending = dict() - for i, idx in enumerate(batch_idxs): - thread = executor.submit(self.dataset.__getitem__, idx) - pending[thread] = i - - # Store results - patches = np.zeros((len(batch_idxs),) + self.patches_shape) - targets = np.zeros((len(batch_idxs), 1)) - for thread in as_completed(pending.keys()): - i = pending.pop(thread) - patches[i], _, targets[i] = thread.result() - return ml_util.to_tensor(patches), ml_util.to_tensor(targets) - - def _load_image_pc_batch(self, batch_idxs): - """ - Loads a batch of samples from the dataset using multithreading. - - Parameters - ---------- - batch_idxs : List[int] - Indices of the dataset items to include in the batch. - - Returns - ------- - batch : Dict[str, torch.Tensor] - Dictionary that maps modality names to batch features. - targets : torch.Tensor - Target labels corresponding to each patch. - """ - with ThreadPoolExecutor() as executor: - # Assign threads - pending = dict() - for i, idx in enumerate(batch_idxs): - thread = executor.submit(self.dataset.__getitem__, idx) - pending[thread] = i - - # Store results - patches = np.zeros((len(batch_idxs),) + self.patches_shape) - targets = np.zeros((len(batch_idxs), 1)) - point_clouds = np.zeros((len(batch_idxs), 3, 3600)) - for thread in as_completed(pending.keys()): - i = pending.pop(thread) - patches[i], subgraph, targets[i] = thread.result() - point_clouds[i] = subgraph_to_point_cloud(subgraph) - - # Set batch dictionary - batch = ml_util.TensorDict( - { - "img": ml_util.to_tensor(patches), - "point_cloud": ml_util.to_tensor(point_clouds), - } - ) - return batch, ml_util.to_tensor(targets) - - def _load_image_graph_batch(self, idxs): - """ - Loads a batch of samples from the dataset using multithreading. - - Parameters - ---------- - idxs : List[int] - Indices of the dataset items to include in the batch. - - Returns - ------- - batch : Dict[str, torch.Tensor] - Dictionary that maps modality names to batch features. - targets : torch.Tensor - Target labels corresponding to each patch. - """ - with ThreadPoolExecutor() as executor: - # Assign threads - threads = list() - for idx in idxs: - threads.append(executor.submit(self.dataset.__getitem__, idx)) - - # Store results - targets = np.zeros((len(idxs), 1)) - patches = np.zeros((len(idxs),) + self.patches_shape) - h, x, edge_index, batches = list(), list(), list(), list() - node_offset = 0 - for i, thread in enumerate(as_completed(threads)): - patches[i], subgraph, targets[i] = thread.result() - h_i, x_i, edge_index_i = subgraph_to_data(subgraph) - n_i = h_i.size(0) - - edge_index_i += node_offset - h.append(h_i) - x.append(x_i) - edge_index.append(edge_index_i) - batches.append(torch.full((n_i,), i, dtype=torch.long)) - - node_offset += n_i - - # Combine subgraph batches - h = torch.cat(h, dim=0) - x = torch.cat(x, dim=0) - edge_index = torch.cat(edge_index, dim=1) - batches = torch.cat(batches, dim=0) - - # Set batch dictionary - batch = ml_util.TensorDict( - { - "img": ml_util.to_tensor(patches), - "graph": (h, x, edge_index, batches), - } - ) - return batch, ml_util.to_tensor(targets) - - def __len__(self): - return 2 * len(self.dataset) diff --git a/src/neuron_proofreader/utils/util.py b/src/neuron_proofreader/utils/util.py index 6621b40d..4056cbfb 100644 --- a/src/neuron_proofreader/utils/util.py +++ b/src/neuron_proofreader/utils/util.py @@ -439,9 +439,7 @@ def check_gcs_prefix_exists(path): bucket_name, prefix = parse_cloud_path(path) prefix = prefix.rstrip("/") + "/" bucket = storage.Client().bucket(bucket_name) - exists = any( - bucket.list_blobs(prefix=prefix, max_results=1) - ) + exists = any(bucket.list_blobs(prefix=prefix, max_results=1)) return exists From f80ce720199f4170faff9e18245f3694519096bd Mon Sep 17 00:00:00 2001 From: anna-grim Date: Wed, 3 Jun 2026 23:56:00 +0000 Subject: [PATCH 07/11] refactor: class rebalancing, ds creation --- src/neuron_proofreader/configs.py | 3 + .../merge_proofreading/merge_datamodules.py | 144 +++++++++--------- 2 files changed, 74 insertions(+), 73 deletions(-) diff --git a/src/neuron_proofreader/configs.py b/src/neuron_proofreader/configs.py index dc6ea5c6..af5e652e 100644 --- a/src/neuron_proofreader/configs.py +++ b/src/neuron_proofreader/configs.py @@ -101,6 +101,9 @@ class ImageConfig: patch_shape: Tuple[int, int, int] = (128, 128, 128) transform: bool = False + def set_train_mode(self): + self.transform = True + @dataclass class ProposalsConfig: diff --git a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py index d34be19c..29ed1783 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py @@ -28,13 +28,13 @@ from scipy.spatial import KDTree from concurrent.futures import as_completed, ThreadPoolExecutor from torch.utils.data import Dataset, DataLoader +from tqdm import tqdm import networkx as nx import numpy as np import os import pandas as pd import queue -import random import threading import torch @@ -58,21 +58,14 @@ class BrainDataset: Parameters ---------- - anisotropy : Tuple[float] - Voxel-to-physical scaling factors. brain_id : str Unique identifier for this brain. - brightness_clip : int - Maximum raw image intensity before normalisation. - node_spacing : float - Spacing (in microns) between neighbouring graph nodes. - patch_shape : Tuple[int] - Shape of the 3D image patches to extract. subgraph_depth : float Radius (in microns) used when extracting rooted subgraphs. """ - giant_component_cable_length = 3 * 10**4 + giant_component_cable_length = 30000 + random_branching_site_probability = 0.7 def __init__( self, @@ -80,6 +73,7 @@ def __init__( img_path, sites_prefix, swcs_path, + class_ratios=(0.5, 0.5), graph_config=None, img_config=None, random_nonmerge_site_prob=0.5, @@ -88,9 +82,10 @@ def __init__( ): # Instance attributes self.brain_id = brain_id - self.rebalance_classes = rebalance_classes + self.class_ratios = class_ratios self.ignore_fragments = set() self.random_nonmerge_site_prob = random_nonmerge_site_prob + self.rebalance_classes = rebalance_classes self.subgraph_depth = subgraph_depth # Core data structures @@ -194,8 +189,7 @@ def get_site(self, idx): return self.get_random_nonmerge_site() def get_random_nonmerge_site(self): - use_branching = self.valid_branching_nodes and random.random() < 0.5 - if use_branching: + if self.use_branching(): node = util.sample_once(self.valid_branching_nodes) else: node = util.sample_once(self.graph.nodes) @@ -223,6 +217,10 @@ def add_nonmerge_sites(self, num_sites): else: self.nonmerge_sites = pd.DataFrame(new_sites) + def use_branching(self): + outcome = np.random.random() < self.random_branching_site_probability + return outcome and self.valid_branching_nodes + def _has_nearby_branching(self, root, max_depth=60): queue = [(root, 0)] visited = {root} @@ -241,21 +239,26 @@ def _has_nearby_branching(self, root, max_depth=60): return False def _list_indices(self): - # Set idxs - pos_idxs = np.arange(len(self.merge_sites)) - neg_idxs = -np.arange(len(self.nonmerge_sites)) - - # Check for class imbalance - if self.rebalance_classes: - if len(neg_idxs) < len(pos_idxs): - neg_idxs = -pos_idxs - else: - neg_idxs = np.random.choice( - neg_idxs, size=len(pos_idxs), replace=False - ) - return np.concatenate((pos_idxs, neg_idxs)) + # Compute target class counts + n_pos = len(self.merge_sites) + n_neg = len(self.nonmerge_sites) or n_pos + pos_ratio, neg_ratio = self.class_ratios + n_target_neg = min(int(n_pos * neg_ratio / pos_ratio), n_neg) + + # Check whether to rebalance negative examples + size = n_target_neg if self.rebalance_classes else n_neg + neg_idxs = np.random.choice(n_neg, size=size, replace=False) + return np.concatenate((-neg_idxs, np.arange(n_pos))) def __len__(self): + """ + Counts the number of examples in the dataset. + + Returns + ------- + int + Number of examples in the dataset. + """ return len(self._list_indices()) def __repr__(self): @@ -279,15 +282,12 @@ class BrainDatasetCollection(Dataset): Parameters ---------- - brain_datasets : List[BrainDataset] + datasets : List[BrainDataset] One BrainDataset per brain. - augmentation : callable or None - Applied to (2, D, H, W) patch arrays in-place during __getitem__. - Pass None when augmentation is not needed (e.g. validation). """ - def __init__(self, brain_datasets): - self.brain_datasets = brain_datasets + def __init__(self, datasets): + self.datasets = datasets self._index_table = self._build_index_table() def _build_index_table(self): @@ -297,21 +297,21 @@ def _build_index_table(self): Returns ------- - List[Tuple[int, int]] + table : List[Tuple[int, int]] """ table = [] - for b_idx, bd in enumerate(self.brain_datasets): + for b_idx, bd in enumerate(self.datasets): for local_idx in bd._list_indices(): table.append((b_idx, int(local_idx))) return table - # --- Dataset interface --- + # --- Dataset Interface --- def __len__(self): return len(self._index_table) def __getitem__(self, idx): """ - Returns one example: (patches, subgraph, label). + Gets one example: (patches, subgraph, label). Parameters ---------- @@ -320,12 +320,17 @@ def __getitem__(self, idx): Returns ------- - patches : numpy.ndarray shape (2, D, H, W) - subgraph : SkeletonGraph - label : int + ... """ b_idx, local_idx = self._index_table[idx] - return self.brain_datasets[b_idx][local_idx] + return self.datasets[b_idx][local_idx] + + def __repr__(self): + return ( + f"BrainDatasetCollection(" + f"n_brains={len(self.datasets)}, " + f"n_examples={len(self)})" + ) def get_idxs(self): """ @@ -338,30 +343,6 @@ def get_idxs(self): """ return np.arange(len(self._index_table)) - # --- Helpers --- - def brain_ids(self): - """Returns the list of brain IDs in this collection.""" - return [bd.brain_id for bd in self.brain_datasets] - - def n_merge_sites(self): - """Returns the total number of merge sites across all brains.""" - return sum(len(bd.merge_sites) for bd in self.brain_datasets) - - def count_fragments(self): - """Returns the total number of fragments across all brains.""" - return sum( - nx.number_connected_components(bd.graph) - for bd in self.brain_datasets - if bd.graph is not None - ) - - def __repr__(self): - return ( - f"BrainDatasetCollection(" - f"n_brains={len(self.brain_datasets)}, " - f"n_examples={len(self)})" - ) - # --- Dataloader --- class ThreadedDataLoader(DataLoader): @@ -393,9 +374,7 @@ def __init__( self.modality = modality self.use_shuffle = use_shuffle self.prefetch_batches = prefetch_batches - self.patches_shape = (2,) + dataset.brain_datasets[ - 0 - ].patch_loader.patch_shape + self.img_shape = (2,) + dataset.datasets[0].patch_loader.patch_shape # Set batch loader if self.is_multimodal and self.modality == "graph": @@ -468,7 +447,7 @@ def _load_image_batch(self, batch_idxs): pending[thread] = i # Store results - patches = np.zeros((len(batch_idxs),) + self.patches_shape) + patches = np.zeros((len(batch_idxs),) + self.img_shape) targets = np.zeros((len(batch_idxs), 1)) for thread in as_completed(pending.keys()): i = pending.pop(thread) @@ -499,7 +478,7 @@ def _load_image_pc_batch(self, batch_idxs): pending[thread] = i # Store results - patches = np.zeros((len(batch_idxs),) + self.patches_shape) + patches = np.zeros((len(batch_idxs),) + self.shape) targets = np.zeros((len(batch_idxs), 1)) point_clouds = np.zeros((len(batch_idxs), 3, 3600)) for thread in as_completed(pending.keys()): @@ -540,7 +519,7 @@ def _load_image_graph_batch(self, idxs): # Store results targets = np.zeros((len(idxs), 1)) - patches = np.zeros((len(idxs),) + self.patches_shape) + patches = np.zeros((len(idxs),) + self.img_shape) h, x, edge_index, batches = list(), list(), list(), list() node_offset = 0 for i, thread in enumerate(as_completed(threads)): @@ -575,22 +554,32 @@ def _load_image_graph_batch(self, idxs): # --- Sites Loading --- def create_dataset_collection( brain_ids, + dataset_mode, img_prefixes_path, sites_root_path, swcs_root_path, + class_ratios=(0.5, 0.5), graph_config=None, img_config=None, - random_nonmerge_site_prob=0.5, - rebalance_classes=False, subgraph_depth=100, ): + # Set parameters based on mode + assert dataset_mode in ["Train", "Val"] + if dataset_mode == "Train": + img_config.set_train_mode() + random_nonmerge_site_prob = 0.5 + rebalance_classes = True + else: + random_nonmerge_site_prob = 0 + rebalance_classes = False + # Load image prefixes bucket, root_prefix = util.parse_cloud_path(sites_root_path) img_prefixes = util.read_json(img_prefixes_path) # Iterate over brains datasets = list() - for brain_id in brain_ids: + for brain_id in tqdm(brain_ids, desc=f"Load {dataset_mode} Dataset"): # Extract dataset info img_path = os.path.join(img_prefixes[brain_id], "0") segmentation_id = get_segmentation_id(sites_root_path, brain_id) @@ -605,12 +594,21 @@ def create_dataset_collection( img_path, sites_path, swcs_path, + class_ratios=class_ratios, graph_config=graph_config, img_config=img_config, subgraph_depth=subgraph_depth, random_nonmerge_site_prob=random_nonmerge_site_prob, rebalance_classes=rebalance_classes, ) + + # Check whether to generate examples for validation + if dataset_mode == "Val": + num_target_neg = 10 * len(dataset.merge_sites) + num_added_neg = num_target_neg - len(dataset.nonmerge_sites) + dataset.add_nonmerge_sites(num_added_neg) + + # Add dataset to collection datasets.append(dataset) return BrainDatasetCollection(datasets) From 56fb83c93c8d4260300a6565f716e3057c9a3632 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Thu, 4 Jun 2026 18:21:35 +0000 Subject: [PATCH 08/11] refactor: improved swc and graph loader --- src/neuron_proofreader/configs.py | 1 + .../merge_proofreading/merge_datamodules.py | 11 +- src/neuron_proofreader/skeleton_graph.py | 2 + src/neuron_proofreader/utils/graph_util.py | 166 ++++++------------ src/neuron_proofreader/utils/swc_util.py | 84 +++------ 5 files changed, 86 insertions(+), 178 deletions(-) diff --git a/src/neuron_proofreader/configs.py b/src/neuron_proofreader/configs.py index af5e652e..0fca1782 100644 --- a/src/neuron_proofreader/configs.py +++ b/src/neuron_proofreader/configs.py @@ -70,6 +70,7 @@ class GraphConfig: anisotropy: Tuple[float, float, float] = (1.0, 1.0, 1.0) min_cable_length: float = 0.0 + min_swc_pts: int = 1 name: str = "graph_config" node_spacing: float = 1.0 prune_depth: float = 20.0 diff --git a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py index 29ed1783..04d61aa2 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py @@ -17,8 +17,7 @@ BrainDatasetCollection Holds an ordered list of BrainDataset objects. Routes global indices to - the correct brain, exposes split() for train/val partitioning, and is - the object handed to MergeSiteDataLoader. + the correct brain and is the object handed to ThreadedDataLoader. ThreadedDataLoader Custom DataLoader that uses multithreading to fetch image patches from @@ -28,7 +27,6 @@ from scipy.spatial import KDTree from concurrent.futures import as_completed, ThreadPoolExecutor from torch.utils.data import Dataset, DataLoader -from tqdm import tqdm import networkx as nx import numpy as np @@ -107,6 +105,7 @@ def load_fragments(self, config, swcs_path): graph = SkeletonGraph( anisotropy=config.anisotropy, min_cable_length=config.min_cable_length, + min_swc_pts=config.min_swc_pts, node_spacing=config.node_spacing, use_anisotropy=config.use_anisotropy, verbose=config.verbose, @@ -344,7 +343,7 @@ def get_idxs(self): return np.arange(len(self._index_table)) -# --- Dataloader --- +# --- DataLoader --- class ThreadedDataLoader(DataLoader): _VALID_MODALITIES = {None, "graph", "pointcloud"} @@ -564,6 +563,7 @@ def create_dataset_collection( subgraph_depth=100, ): # Set parameters based on mode + print(f"Load {dataset_mode} Dataset") assert dataset_mode in ["Train", "Val"] if dataset_mode == "Train": img_config.set_train_mode() @@ -579,7 +579,7 @@ def create_dataset_collection( # Iterate over brains datasets = list() - for brain_id in tqdm(brain_ids, desc=f"Load {dataset_mode} Dataset"): + for i, brain_id in enumerate(brain_ids, start=1): # Extract dataset info img_path = os.path.join(img_prefixes[brain_id], "0") segmentation_id = get_segmentation_id(sites_root_path, brain_id) @@ -589,6 +589,7 @@ def create_dataset_collection( ) # Add dataset + print(f" \nBrain ID [{i}/{len(brain_ids)}]: {brain_id}") dataset = BrainDataset( brain_id, img_path, diff --git a/src/neuron_proofreader/skeleton_graph.py b/src/neuron_proofreader/skeleton_graph.py index eb75f46d..ca47684f 100644 --- a/src/neuron_proofreader/skeleton_graph.py +++ b/src/neuron_proofreader/skeleton_graph.py @@ -46,6 +46,7 @@ def __init__( self, anisotropy=(1.0, 1.0, 1.0), min_cable_length=0, + min_swc_pts=1, node_spacing=1, prune_depth=20, use_anisotropy=True, @@ -90,6 +91,7 @@ def __init__( self.graph_loader = graph_util.GraphLoader( anisotropy=anisotropy, min_cable_length=min_cable_length, + min_swc_pts=min_swc_pts, node_spacing=node_spacing, prune_depth=prune_depth, verbose=verbose, diff --git a/src/neuron_proofreader/utils/graph_util.py b/src/neuron_proofreader/utils/graph_util.py index 2fbdf4dd..4ad03abb 100644 --- a/src/neuron_proofreader/utils/graph_util.py +++ b/src/neuron_proofreader/utils/graph_util.py @@ -15,13 +15,12 @@ ProcessPoolExecutor, wait, ) -from scipy.spatial.distance import euclidean from tqdm import tqdm import networkx as nx import numpy as np -from neuron_proofreader.utils import geometry_util as geometry, swc_util, util +from neuron_proofreader.utils import geometry_util as geometry, swc_util class GraphLoader: @@ -34,6 +33,7 @@ def __init__( self, anisotropy=(1.0, 1.0, 1.0), min_cable_length=40.0, + min_swc_pts=1, node_spacing=1, prefetch=128, prune_depth=24.0, @@ -64,6 +64,7 @@ def __init__( """ # Instance attributes self.min_cable_length = min_cable_length + self.min_swc_pts = min_swc_pts self.node_spacing = node_spacing self.prefetch = prefetch self.prune_depth = prune_depth @@ -88,13 +89,16 @@ def __call__(self, swc_pointer): """ # Read SWC files swc_dicts = self.swc_reader(swc_pointer) + swc_dicts = deque( + d for d in swc_dicts if len(d["xyz"]) > self.min_swc_pts + ) if self.verbose: pbar = tqdm(total=len(swc_dicts), desc="Load Graphs") # Load graphs + pending = set() with ProcessPoolExecutor() as executor: # Start processes - pending = set() while len(pending) < self.prefetch and swc_dicts: pending.add(executor.submit(self.load, swc_dicts.pop())) @@ -107,13 +111,14 @@ def __call__(self, swc_pointer): result = future.result() if result: irreducibles.append(result) - pbar.update(1) if self.verbose else None + + # Update progress bar + if self.verbose: + pbar.update(1) # Continue submitting processes if swc_dicts: - pending.add( - executor.submit(self.load, swc_dicts.pop()) - ) + pending.add(executor.submit(self.load, swc_dicts.pop())) return irreducibles def load(self, swc_dict): @@ -137,11 +142,10 @@ def load(self, swc_dict): prune_branches(graph, self.prune_depth) # Extract irreducible components (if applicable) - if self.satisfies_cable_length_condition(graph): - irreducibles = self.get_irreducibles(graph) - if irreducibles: - irreducibles["is_soma"] = len(swc_dict["soma_nodes"]) > 0 - irreducibles["swc_id"] = swc_dict["swc_name"] + irreducibles = self.get_irreducibles(graph) + if irreducibles: + irreducibles["is_soma"] = len(swc_dict["soma_nodes"]) > 0 + irreducibles["swc_id"] = swc_dict["swc_name"] return irreducibles else: return None @@ -178,12 +182,15 @@ def dist(i, j): float Distance between nodes. """ - return euclidean(graph.graph["xyz"][i], graph.graph["xyz"][j]) + return np.linalg.norm(xyz[i] - xyz[j]) # Initializations leaf = find_leaf(graph) - irreducible_nodes = {leaf} - irreducible_edges = dict() + irr_nodes = {leaf} + irr_edges = dict() + + radius = graph.graph["radius"] + xyz = graph.graph["xyz"] # Main root, cable_length = None, 0 @@ -191,68 +198,43 @@ def dist(i, j): # Check for start of irreducible edge if root is None: root, edge_length = i, 0 - attrs = { - "radius": [graph.graph["radius"][i]], - "xyz": [graph.graph["xyz"][i]], - } + attrs = {"radius": [radius[i]], "xyz": [xyz[i]]} # Visit node - edge_length += dist(i, j) - attrs["radius"].append(graph.graph["radius"][j]) - attrs["xyz"].append(graph.graph["xyz"][j]) + edge_length += np.linalg.norm(xyz[i] - xyz[j]) + attrs["radius"].append(radius[j]) + attrs["xyz"].append(xyz[j]) # Check for end of irreducible edge if graph.degree[j] != 2: cable_length += edge_length - irreducible_nodes.add(j) + irr_nodes.add(j) attrs = to_numpy(attrs) n_pts = int(edge_length / self.node_spacing) self.resample_curve_3d(graph, attrs, (root, j), n_pts) - irreducible_edges[(root, j)] = attrs + irr_edges[(root, j)] = attrs root = None # Check for curvy line fragment - if len(irreducible_nodes) == 2: - endpoint_dist = dist(*tuple(irreducible_nodes)) + if len(irr_nodes) == 2: + t0, t1 = irr_nodes + endpoint_dist = np.linalg.norm(xyz[t0] - xyz[t1]) if endpoint_dist / cable_length < 0.5: return None # Store results - if cable_length > self.min_cable_length: + if cable_length >= self.min_cable_length: irreducibles = { - "nodes": set_node_attrs(graph, irreducible_nodes), - "edges": set_edge_attrs(graph, irreducible_edges), + "nodes": set_node_attrs(graph, irr_nodes), + "edges": set_edge_attrs(graph, irr_edges), } else: irreducibles = None return irreducibles # --- Helpers --- - def satisfies_cable_length_condition(self, graph): - """ - Determines whether the cable length of the given graph is greater - than "self.min_cable_length". - - Parameters - ---------- - graph : networkx.Graph - Graph to be checked. - - Returns - ------- - bool - Indication of whether the total cable length of the given graph is - greater than "self.min_cable_length". - """ - length = 0 - for i, j in nx.dfs_edges(graph): - length += euclidean(graph.graph["xyz"][i], graph.graph["xyz"][j]) - if length > self.min_cable_length: - return True - return False - def resample_curve_3d(self, graph, attrs, edge, n_pts): """ Smooths a 3D curve and update the corresponding edge endpoints in the @@ -294,13 +276,8 @@ def set_node_attrs(graph, nodes): Dictionary where the keys are node ids and values are dictionaries containing the "radius" and "xyz" attributes of the nodes. """ - attrs = dict() - for i in nodes: - attrs[i] = { - "radius": graph.graph["radius"][i], - "xyz": graph.graph["xyz"][i], - } - return attrs + xyz, radius = graph.graph["xyz"], graph.graph["radius"] + return {i: {"radius": radius[i], "xyz": xyz[i]} for i in nodes} def set_edge_attrs(graph, attrs): @@ -322,16 +299,16 @@ def set_edge_attrs(graph, attrs): Updated edge attribute dictionary. """ for e in attrs: - i, j = tuple(e) + i, j = e attrs[e]["xyz"][0] = graph.graph["xyz"][i] attrs[e]["xyz"][-1] = graph.graph["xyz"][j] return attrs # --- Miscellaneous --- -def count_nodes(irreducibles): +def count_nodes(irr_list): n = 0 - for irr in irreducibles: + for irr in irr_list: n += len(irr["nodes"]) for attrs in irr["edges"].values(): n += len(attrs["xyz"]) - 2 @@ -395,39 +372,7 @@ def find_leaf(graph): for i in graph.nodes: if graph.degree[i] == 1: return i - - -def largest_components(graph, k): - """ - Finds the "k" largest connected components in "graph". - - Parameters - ---------- - graph : nx.Graph - Graph to be searched. - k : int - Number of largest connected components to return. - - Returns - ------- - node_ids : List[int] - List where each entry is a random node from one of the k largest - connected components. - """ - component_cardinalities = k * [-1] - node_ids = k * [-1] - for nodes in nx.connected_components(graph): - if len(nodes) > component_cardinalities[-1]: - i = 0 - while i < k: - if len(nodes) > component_cardinalities[i]: - component_cardinalities.insert(i, len(nodes)) - component_cardinalities.pop(-1) - node_ids.insert(i, util.sample_singleton(nodes)) - node_ids.pop(-1) - break - i += 1 - return node_ids + return None def prune_branches(graph, depth): @@ -441,21 +386,22 @@ def prune_branches(graph, depth): depth : float Length of branches that are pruned. """ - for leaf in [i for i in graph.nodes if graph.degree[i] == 1]: - branch = [leaf] - length = 0 - for i, j in nx.dfs_edges(graph, source=leaf): - # Visit edge - length += euclidean(graph.graph["xyz"][i], graph.graph["xyz"][j]) - if length > depth: - break - - # Check whether to continue search - if graph.degree(j) == 2: - branch.append(j) - elif graph.degree(j) > 2: - graph.remove_nodes_from(branch) - break + xyz = graph.graph["xyz"] + changed = True + while changed: + changed = False + for leaf in [i for i in graph.nodes if graph.degree[i] == 1]: + branch, length = [leaf], 0 + for i, j in nx.dfs_edges(graph, source=leaf): + length += np.linalg.norm(xyz[i] - xyz[j]) + if length > depth: + break + if graph.degree(j) == 2: + branch.append(j) + elif graph.degree(j) > 2: + graph.remove_nodes_from(branch) + changed = True + break def to_numpy(attrs): diff --git a/src/neuron_proofreader/utils/swc_util.py b/src/neuron_proofreader/utils/swc_util.py index c57c2ad0..4ded475d 100644 --- a/src/neuron_proofreader/utils/swc_util.py +++ b/src/neuron_proofreader/utils/swc_util.py @@ -26,7 +26,7 @@ ThreadPoolExecutor, as_completed, ) -from google.auth.exceptions import RefreshError, TransportError +from google.auth.exceptions import TransportError from google.cloud import storage from io import BytesIO, StringIO from tqdm import tqdm @@ -206,7 +206,7 @@ def read_zips(self, zip_paths, read_fn): for process in as_completed(futures): try: swc_dicts.extend(process.result()) - except (TransportError, RefreshError): + except Exception: pass if self.verbose: @@ -280,7 +280,7 @@ def read_from_cloud(self, path): Dictionaries whose keys and values are the attribute names and values from an SWC file. """ - # Extact info + # Extract info assert util.is_s3_path(path) or util.is_gcs_path(path) use_s3 = util.is_s3_path(path) @@ -313,13 +313,13 @@ def read_gcs_swc(self, path): values from an SWC file. """ # Initialize cloud reader - bucket_name, subpath = util.parse_cloud_path(path) + bucket_name, key = util.parse_cloud_path(path) bucket = storage.Client().bucket(bucket_name) - blob = bucket.blob(subpath) + blob = bucket.blob(key) # Parse swc contents content = blob.download_as_text().splitlines() - filename = os.path.basename(subpath) + filename = os.path.basename(key) return self.parse(content, filename) def read_gcs_zip(self, path): @@ -338,33 +338,14 @@ def read_gcs_zip(self, path): Dictionaries whose keys and values are the attribute names and values from an SWC file. """ - # Download ZIP - bucket_name, path = util.parse_cloud_path(path) + bucket_name, key = util.parse_cloud_path(path) bucket = storage.Client().bucket(bucket_name) try: - zip_content = bucket.blob(path).download_as_bytes() + zip_content = bucket.blob(key).download_as_bytes() except TransportError: print(f"Failed to read {path}!") return deque() - - # Parse ZIP - swc_dicts = deque() - zip_content = bucket.blob(path).download_as_bytes() - with ZipFile(BytesIO(zip_content), "r") as zf: - with ThreadPoolExecutor() as executor: - # Assign threads - threads = set() - for name in zf.namelist(): - threads.add( - executor.submit(self.read_zipped_swc, zf, name) - ) - - # Process results - for thread in as_completed(threads): - result = thread.result() - if result: - swc_dicts.append(result) - return swc_dicts + return self._parse_zip_bytes(zip_content) def read_s3_zip(self, path): """ @@ -382,46 +363,24 @@ def read_s3_zip(self, path): Dictionaries whose keys and values are the attribute names and values from an SWC file. """ - # Initialize cloud reader bucket, key = util.parse_cloud_path(path) s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED)) zip_content = s3.get_object(Bucket=bucket, Key=key)["Body"].read() + return self._parse_zip_bytes(zip_content) - # Parse ZIP + def _parse_zip_bytes(self, zip_content): with ZipFile(BytesIO(zip_content), "r") as zf: + names = [f for f in zf.namelist() if f.endswith(".swc")] with ThreadPoolExecutor() as executor: - # Assign threads - threads = set() - for name in zf.namelist(): - threads.add( - executor.submit(self.read_zipped_swc, zf, name) - ) - - # Store results - swc_dicts = deque() - for thread in as_completed(threads): - result = thread.result() - if result: - swc_dicts.append(result) - return swc_dicts + threads = { + executor.submit(self.read_zipped_swc, zf, name) + for name in names + } + return deque( + t.result() for t in as_completed(threads) if t.result() + ) # -- Process Text --- - def iterator(self, iterator): - """ - Gets an iterator that optionally displays a progress bar. - - Parameters - ---------- - iterator : iterable - Object to be iterated over. - - Returns - ------- - tqdm.tqdm - Iterator that is optionally wrapped in a progress bar. - """ - return tqdm(iterator, desc="Read SWCs") if self.verbose else iterator - def manual_progress_bar(self, total): """ Gets progress bar that needs to be updated manually. @@ -509,6 +468,7 @@ def process_content(self, content): offset = self.read_coordinate(parts[2:5]) if not line.startswith("#") and len(line.strip()) > 0: return content[i:], offset + return [], offset def read_coordinate(self, xyz_str, offset=(0, 0, 0)): """ @@ -627,9 +587,7 @@ def get_swc_name(path): name : str Name of the SWC file, minus the extension. """ - filename = os.path.basename(path) - name, ext = os.path.splitext(filename) - return name + return os.path.splitext(os.path.basename(path))[0] def to_graph(swc_dict): From 3fccc71389763d01931513536fdb53ae0ec12b0a Mon Sep 17 00:00:00 2001 From: anna-grim Date: Fri, 5 Jun 2026 19:03:40 +0000 Subject: [PATCH 09/11] bug: config subclassing --- src/neuron_proofreader/configs.py | 6 ++--- .../machine_learning/train.py | 19 +++++++++++++- .../merge_proofreading/merge_datamodules.py | 9 ++++--- src/neuron_proofreader/utils/graph_util.py | 26 ++++--------------- src/neuron_proofreader/utils/swc_util.py | 17 +++++++----- 5 files changed, 41 insertions(+), 36 deletions(-) diff --git a/src/neuron_proofreader/configs.py b/src/neuron_proofreader/configs.py index 0fca1782..90be82ef 100644 --- a/src/neuron_proofreader/configs.py +++ b/src/neuron_proofreader/configs.py @@ -47,7 +47,7 @@ def save(self, output_dir): @dataclass -class GraphConfig: +class GraphConfig(Config): """ Configuration class for skeleton graph parameters. @@ -80,7 +80,7 @@ class GraphConfig: @dataclass -class ImageConfig: +class ImageConfig(Config): """ Configuration class for image processing parameters. @@ -107,7 +107,7 @@ def set_train_mode(self): @dataclass -class ProposalsConfig: +class ProposalsConfig(Config): """ Configuration class for skeleton graph parameters. diff --git a/src/neuron_proofreader/machine_learning/train.py b/src/neuron_proofreader/machine_learning/train.py index 20edafb4..e53526f1 100644 --- a/src/neuron_proofreader/machine_learning/train.py +++ b/src/neuron_proofreader/machine_learning/train.py @@ -16,6 +16,7 @@ from torch.optim.lr_scheduler import CosineAnnealingLR from torch.utils.tensorboard import SummaryWriter from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm import numpy as np import os @@ -69,6 +70,7 @@ def __init__( max_epochs=200, min_recall=0, save_mistake_mips=False, + verbose=False, ): """ Instantiates a Trainer object. @@ -105,6 +107,7 @@ def __init__( self.mistakes_dir = os.path.join(log_dir, "mistakes") self.model_name = model_name self.save_mistake_mips = save_mistake_mips + self.verbose = verbose self.criterion = nn.BCEWithLogitsLoss() self.model = model.to(device) @@ -129,12 +132,13 @@ def run(self, train_dataloader, val_dataloader): print("\nExperiment:", exp_name) for epoch in range(self.max_epochs): # Train-Validate + print("\nEpoch", epoch) train_stats = self.train_step(train_dataloader, epoch) val_stats = self.validate_step(val_dataloader, epoch) new_best = self.check_model_performance(val_stats, epoch) # Report reuslts - print(f"\nEpoch {epoch}: " + ("New Best!" if new_best else " ")) + print(f"Results: " + ("New Best!" if new_best else " ")) self.report_stats(train_stats, is_train=True) self.report_stats(val_stats, is_train=False) @@ -157,6 +161,11 @@ def train_step(self, dataloader, epoch): stats : Dict[str, float] Dictionary of aggregated training metrics. """ + # Create progress bar (if applicable) + if self.verbose: + pbar = tqdm(total=len(dataloader), desc="Train") + + # Train for an epoch self.model.train() loss, y, hat_y = list(), list(), list() for x_i, y_i in dataloader: @@ -178,6 +187,10 @@ def train_step(self, dataloader, epoch): hat_y.extend(ml_util.to_cpu(hat_y_i, True).flatten().tolist()) loss.append(float(ml_util.to_cpu(loss_i))) + # Update progress bar + if self.verbose: + pbar.update(1) + # Write stats to tensorboard stats = self.compute_stats(y, hat_y) stats["loss"] = np.mean(loss) @@ -202,6 +215,10 @@ def validate_step(self, dataloader, epoch): is_best : bool True if the current F1 score is the best so far. """ + # Create progress bar (if applicable) + if self.verbose: + pbar = tqdm(total=len(dataloader), desc="Val") + # Initializations idx_offset = 0 loss_accum = 0 diff --git a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py index 04d61aa2..488463df 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py @@ -356,7 +356,7 @@ def __init__( modality=None, sampler=None, use_shuffle=True, - prefetch_batches=8, + prefetch=8, ): # Check that modality is valid if modality not in self._VALID_MODALITIES: @@ -372,7 +372,7 @@ def __init__( self.is_multimodal = is_multimodal self.modality = modality self.use_shuffle = use_shuffle - self.prefetch_batches = prefetch_batches + self.prefetch = prefetch self.img_shape = (2,) + dataset.datasets[0].patch_loader.patch_shape # Set batch loader @@ -398,7 +398,7 @@ def __iter__(self): # Sentinel signalling the prefetch thread is done _DONE = object() - buffer = queue.Queue(maxsize=self.prefetch_batches) + buffer = queue.Queue(maxsize=self.prefetch) def prefetch_worker(): try: @@ -563,7 +563,7 @@ def create_dataset_collection( subgraph_depth=100, ): # Set parameters based on mode - print(f"Load {dataset_mode} Dataset") + print(f"\nLoading {dataset_mode} Dataset...") assert dataset_mode in ["Train", "Val"] if dataset_mode == "Train": img_config.set_train_mode() @@ -602,6 +602,7 @@ def create_dataset_collection( random_nonmerge_site_prob=random_nonmerge_site_prob, rebalance_classes=rebalance_classes, ) + print(dataset) # Check whether to generate examples for validation if dataset_mode == "Val": diff --git a/src/neuron_proofreader/utils/graph_util.py b/src/neuron_proofreader/utils/graph_util.py index 4ad03abb..dad3464b 100644 --- a/src/neuron_proofreader/utils/graph_util.py +++ b/src/neuron_proofreader/utils/graph_util.py @@ -51,6 +51,8 @@ def __init__( min_cable_length : float, optional Minimum cable length (in microns) of SWC files that are loaded. Default is 40. + min_swc_pts : int, optional + ... node_spacing : int, optional Spacing (in microns) between neighboring nodes. Default is 1. prefetch : int, optional @@ -68,7 +70,7 @@ def __init__( self.node_spacing = node_spacing self.prefetch = prefetch self.prune_depth = prune_depth - self.swc_reader = swc_util.Reader(anisotropy, verbose) + self.swc_reader = swc_util.Reader(anisotropy, min_swc_pts, verbose) self.verbose = verbose def __call__(self, swc_pointer): @@ -165,25 +167,6 @@ def get_irreducibles(self, graph): Dictionary containing the irreducible components of a connected graph. """ - - def dist(i, j): - """ - Computes distance between the given nodes. - - Parameters - ---------- - i : int - Node ID. - j : int - Node ID. - - Returns - ------- - float - Distance between nodes. - """ - return np.linalg.norm(xyz[i] - xyz[j]) - # Initializations leaf = find_leaf(graph) irr_nodes = {leaf} @@ -377,7 +360,8 @@ def find_leaf(graph): def prune_branches(graph, depth): """ - Prunes branches with length less than "depth" microns. + Prunes paths between leaf and branching nodes with cable length less than + "depth" microns. Parameters ---------- diff --git a/src/neuron_proofreader/utils/swc_util.py b/src/neuron_proofreader/utils/swc_util.py index 4ded475d..00e828fa 100644 --- a/src/neuron_proofreader/utils/swc_util.py +++ b/src/neuron_proofreader/utils/swc_util.py @@ -47,19 +47,24 @@ class Reader: archive, and (3) local directory of ZIP archives. """ - def __init__(self, anisotropy=(1.0, 1.0, 1.0), verbose=True): + def __init__( + self, anisotropy=(1.0, 1.0, 1.0), min_swc_pts=1, verbose=True + ): """ - Initializes a Reader object that reads SWC files. + Instantiates a Reader object for reading SWC files. Parameters ---------- anisotropy : Tuple[float], optional Image to physical coordinates scaling factors to account for the anisotropy of the microscope. Default is (1.0, 1.0, 1.0). + min_swc_pts : int, optional + ... verbose : bool, optional Indication of whether to display a progress bar. Default is True. """ self.anisotropy = anisotropy + self.min_swc_pts = min_swc_pts self.verbose = verbose # --- Read Data --- @@ -164,13 +169,11 @@ def read_swcs(self, swc_paths): """ with ThreadPoolExecutor() as executor: # Assign threads - threads = set() - for path in swc_paths: - threads.add(executor.submit(self.read_swc, path)) + threads = {executor.submit(self.read_swc, p) for p in swc_paths} + pbar = self.manual_progress_bar(len(threads)) # Store results swc_dicts = deque() - pbar = self.manual_progress_bar(len(threads)) for thread in as_completed(threads): result = thread.result() if result: @@ -415,7 +418,7 @@ def parse(self, content, filename): # Initializations swc_name, _ = os.path.splitext(filename) content, offset = self.process_content(content) - if len(content) > 0: + if len(content) >= self.min_swc_pts: swc_dict = { "id": np.zeros((len(content)), dtype=int), "pid": np.zeros((len(content)), dtype=int), From 57d205a376c0bc90c744d2c3c7da5dbde91ec704 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Mon, 8 Jun 2026 18:28:59 +0000 Subject: [PATCH 10/11] refactor: random site generation --- .../machine_learning/image_dataloader.py | 2 +- .../machine_learning/train.py | 7 +- .../merge_proofreading/merge_dataloading.py | 217 ------------------ .../merge_proofreading/merge_datamodules.py | 152 +++++++----- 4 files changed, 107 insertions(+), 271 deletions(-) delete mode 100644 src/neuron_proofreader/merge_proofreading/merge_dataloading.py diff --git a/src/neuron_proofreader/machine_learning/image_dataloader.py b/src/neuron_proofreader/machine_learning/image_dataloader.py index a9035b0b..a0c1834d 100644 --- a/src/neuron_proofreader/machine_learning/image_dataloader.py +++ b/src/neuron_proofreader/machine_learning/image_dataloader.py @@ -234,7 +234,7 @@ def create_mask(self, center, shape, node): # Annotate mask mask = np.zeros(shape) - self.annotate_foreground(mask, nodes, offset, fill_val=0.5) + #self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP self.annotate_fragment(mask, subgraph, offset, fill_val=1) return mask diff --git a/src/neuron_proofreader/machine_learning/train.py b/src/neuron_proofreader/machine_learning/train.py index e53526f1..e0cca487 100644 --- a/src/neuron_proofreader/machine_learning/train.py +++ b/src/neuron_proofreader/machine_learning/train.py @@ -129,6 +129,7 @@ def run(self, train_dataloader, val_dataloader): Dataloader used for validation. """ exp_name = os.path.basename(os.path.normpath(self.log_dir)) + val_dataloader.dataset.save_val_summary(self.log_dir) print("\nExperiment:", exp_name) for epoch in range(self.max_epochs): # Train-Validate @@ -138,7 +139,7 @@ def run(self, train_dataloader, val_dataloader): new_best = self.check_model_performance(val_stats, epoch) # Report reuslts - print(f"Results: " + ("New Best!" if new_best else " ")) + print("Results: " + ("New Best!" if new_best else " ")) self.report_stats(train_stats, is_train=True) self.report_stats(val_stats, is_train=False) @@ -247,6 +248,10 @@ def validate_step(self, dataloader, epoch): self._save_mistake_mips(x, y, hat_y, idx_offset) idx_offset += len(y) + # Update progress bar + if self.verbose: + pbar.update(1) + # Write stats to tensorboard stats = self.compute_stats(y_accum, hat_y_accum) stats["loss"] = loss_accum / len(y_accum) diff --git a/src/neuron_proofreader/merge_proofreading/merge_dataloading.py b/src/neuron_proofreader/merge_proofreading/merge_dataloading.py deleted file mode 100644 index d9c6fd50..00000000 --- a/src/neuron_proofreader/merge_proofreading/merge_dataloading.py +++ /dev/null @@ -1,217 +0,0 @@ -""" -Created on Thu Nov 13 11:00:00 2025 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -Code for loading merge site dataset. - -""" - -import ast -import numpy as np -import os -import pandas as pd - -from neuron_proofreader.utils import util - -TEST_BRAIN = "747807" - - -# --- Load Skeletons --- -def load_fragments(dataset, is_test=False): - """ - Loads neuron fragments for a selected set of merge-site indices into - dataset. - - Parameters - ---------- - dataset : MergeSiteDataset - Dataset that fragments are loaded into. - is_test : bool, optional - Indication of whether this is a test run so only fragments from a - single brain should be loaded. Default is False. - """ - # Initializations - merge_sites_df = dataset.merge_sites_df - target_pairs = get_brain_segmentation_pairs(merge_sites_df) - root = "gs://allen-nd-goog/automated_proofreading_dataset/raw_merge_sites" - - # Main - print("\nLoading Fragments") - for brain_id in get_brain_ids(merge_sites_df, is_test): - sub_df = merge_sites_df.loc[merge_sites_df["brain_id"] == brain_id] - for segmentation_id in sub_df["segmentation_id"].unique(): - if (brain_id, segmentation_id) in target_pairs: - swc_pointer = ( - f"{root}/{brain_id}/{segmentation_id}/merged_fragments.zip" - ) - dataset.load_fragment_graphs( - brain_id, swc_pointer, use_anisotropy=False - ) - - -def load_groundtruth(dataset, is_test=False): - """ - Loads ground truth skeletons into dataset. - - Parameters - ---------- - dataset : MergeSiteDataset - Dataset that fragments are loaded into. - is_test : bool, optional - Indication of whether this is a test run so only fragments from a - single brain should be loaded. Default is False. - """ - print("\nLoading Ground Truth") - root = "gs://allen-nd-goog/ground_truth_tracings" - for brain_id in get_brain_ids(dataset.merge_sites_df, is_test): - swc_pointer = f"{root}/{brain_id}/voxel" - dataset.load_gt_graphs(brain_id, swc_pointer) - - -def load_images( - dataset, img_prefixes_path, segmentation_prefixes_path, is_test=False -): - """ - Loads images into dataset. - - Parameters - ---------- - dataset : MergeSiteDataset - Dataset that fragments are loaded into. - img_prefixes_path : str, optional - Path to json that maps brain IDs to S3 image paths. Default is None. - segmentation_prefixes_path : str, optional - Path to json that maps brain IDs to segmentation paths. Default is - None. - is_test : bool, optional - Indication of whether this is a test run so only fragments from a - single brain should be loaded. Default is False. - """ - img_prefixes = util.read_json(img_prefixes_path) - segmentation_prefixes = util.read_json(segmentation_prefixes_path) - for brain_id in get_brain_ids(dataset.merge_sites_df, is_test): - img_path = os.path.join(img_prefixes[brain_id], "0") - segmentation_path = segmentation_prefixes[brain_id] - dataset.load_images(brain_id, img_path, segmentation_path) - - -# --- Process Merge Site DataFrame --- -def get_brain_segmentation_pairs(merge_sites_df): - """ - Extracts unique (brain_id, segmentation_id) pairs from a merge sites - dataframe. - - Parameters - ---------- - merge_sites_df : pandas.DataFrame - DataFrame containing merge site information. Must have columns: - - 'brain_id' : unique identifier of the whole-brain dataset - - 'segmentation_id' : unique identifier of the segmentation - - Returns - ------- - brain_segmentation_pairs : Set[Tuple[str]] - Unique (brain_id, segmentation_id) pairs from a merge sites dataframe. - """ - pairs = set() - for i in range(len(merge_sites_df)): - brain_id = merge_sites_df["brain_id"][i] - segmentation_id = merge_sites_df["segmentation_id"][i] - pairs.add((brain_id, segmentation_id)) - return pairs - - -def get_brain_merge_sites(merge_sites_df, brain_id): - """ - Gets the xyz coordinates of ground truth merge sites for a given brain. - - Parameters - ---------- - merge_sites_df : pandas.DataFrame - DataFrame containing merge sites, must contain the columns: - "brain_id", "segmentation_id", "segment_id", and "xyz". - brain_id : str - Unique identifier for a whole-brain dataset. - - Returns - ------- - numpy.ndarray - Ground-truth merge sites (xyz coordinates) for a given brain. - """ - idx_mask = merge_sites_df["brain_id"] == brain_id - return np.array(merge_sites_df.loc[idx_mask, "xyz"].tolist()) - - -def load_merge_sites_df(path, is_test=False): - """ - Loads a merge sites dataframe from a CSV file and process its columns. - - Parameters - ---------- - path : str - Path to the CSV file containing merge site data. The CSV must include - the columns: 'brain_id', 'segment_id', and 'xyz'. - is_test : bool, optional - Indication of whether this is a test run so only sites from a single - brain should be loaded. Default is False. - - Returns - ------- - merge_sites_df : pandas.DataFrame - Processed dataframe with the following modifications: - - 'brain_id' and 'segment_id' converted to strings. - - 'xyz' converted from string representation to tuple - """ - # Read and process - merge_sites_df = pd.read_csv(path) - merge_sites_df["brain_id"] = merge_sites_df["brain_id"].apply(str) - merge_sites_df["segment_id"] = merge_sites_df["segment_id"].apply(str) - merge_sites_df["xyz"] = merge_sites_df["xyz"].apply(ast.literal_eval) - - # Check whether test run - if is_test: - idx_mask = merge_sites_df["brain_id"] == TEST_BRAIN - return merge_sites_df[idx_mask].reset_index(drop=True) - else: - return merge_sites_df - - -# --- Helpers --- -def get_brain_ids(merge_sites_df, is_test=False): - """ - Gets brain IDs of datasets to be loaded. - - Parameters - ---------- - merge_sites_df : pandas.DataFrame - DataFrame containing merge sites, must contain the columns: - "brain_id", "segmentation_id", "segment_id", and "xyz". - is_test : bool, optional - Indication of whether this is a test run so only fragments from a - single brain should be loaded. Default is False. - - Returns - ------- - List[str] - Brain IDs of datasests to be loaded. - """ - return [TEST_BRAIN] if is_test else merge_sites_df["brain_id"].unique() - - -def read_idxs(path): - """ - Reads a list of indices from a CSV file. - - Parameters - ---------- - path : str - Path to the CSV file. - - Returns - ------- - List[int] - Indices extracted from the CSV file. - """ - return pd.read_csv(path)["Indices"] diff --git a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py index 488463df..d9b336a1 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py @@ -63,7 +63,7 @@ class BrainDataset: """ giant_component_cable_length = 30000 - random_branching_site_probability = 0.7 + random_branching_site_probability = 0.5 def __init__( self, @@ -99,7 +99,6 @@ def __init__( # Store dataset info self.set_giant_components() self.set_merge_site_info() - self.set_valid_branching_nodes() def load_fragments(self, config, swcs_path): graph = SkeletonGraph( @@ -118,7 +117,7 @@ def load_sites(self, sites_prefix): swc_reader = swc_util.Reader(verbose=False) for swc_dict in swc_reader(sites_prefix): xyz = swc_dict["xyz"][0] - dd, ii = self.graph.kdtree.query(xyz) + dd, ii = self.kdtree.query(xyz) sites.append( {"xyz": xyz, "node": ii, "filename": swc_dict["swc_name"]} ) @@ -132,48 +131,25 @@ def set_merge_site_info(self): # Store fragment IDs corresponding to merge sites for xyz in self.merge_sites["xyz"]: - _, ii = self.graph.kdtree.query(xyz) - self.ignore_fragments.add(self.graph.node_component_id[ii]) + _, ii = self.kdtree.query(xyz) + #self.ignore_fragments.add(self.node_component_id[ii]) def set_giant_components(self): for nodes in map(list, nx.connected_components(self.graph)): # Compute cable length root = util.sample_once(nodes) - cable_length = self.graph.cable_length( + cable_length = self.cable_length( max_depth=self.giant_component_cable_length, root=root ) # Check if giant component if cable_length > self.giant_component_cable_length: - self.ignore_fragments.add(self.graph.node_component_id[root]) - - def set_valid_branching_nodes(self): - self.valid_branching_nodes = set() - for node in self.graph.branching_nodes(): - # Reject if high-degree - if self.graph.degree(node) > 3: - continue - - # Reject if node belongs to fragment with merge - # if self.graph.node_component_id[node] in self.ignore_fragments: - # continue - - # Reject if branching and near another branching node - if self._has_nearby_branching(node): - continue - - # Reject if near merge site - dd, _ = self.merge_sites_kdtree.query(self.graph.node_xyz[node]) - if dd < 50: - continue - - # Node is valid - self.valid_branching_nodes.add(node) + self.ignore_fragments.add(self.node_component_id[root]) # --- Site Retrieval --- def __getitem__(self, idx): node, label = self.get_site(idx) - subgraph = self.graph.rooted_subgraph(node, self.subgraph_depth) + subgraph = self.rooted_subgraph(node, self.subgraph_depth) patches = self.patch_loader(node) return patches, subgraph, label @@ -188,11 +164,20 @@ def get_site(self, idx): return self.get_random_nonmerge_site() def get_random_nonmerge_site(self): - if self.use_branching(): - node = util.sample_once(self.valid_branching_nodes) - else: - node = util.sample_once(self.graph.nodes) - return node, 0 + use_br = np.random.random() < self.random_branching_site_probability + nodes = self.branching_nodes() if use_br else self.nodes + n_attempts = 0 + while True: + # Sample node + node = util.sample_once(nodes) + if self.is_valid_nonmerge_site(node): + return node + + # Try again + n_attempts += 1 + if n_attempts > 100: + print(f"Failed to find valid random nonmerge site for {self.brain_id}!") + return util.sample_once(self.nodes) # --- Helpers --- def add_nonmerge_sites(self, num_sites): @@ -202,7 +187,7 @@ def add_nonmerge_sites(self, num_sites): node, _ = self.get_random_nonmerge_site() site = { "node": node, - "xyz": self.graph.node_xyz[node], + "xyz": self.node_xyz[node], "filename": "random", } new_sites.append(site) @@ -216,22 +201,39 @@ def add_nonmerge_sites(self, num_sites): else: self.nonmerge_sites = pd.DataFrame(new_sites) - def use_branching(self): - outcome = np.random.random() < self.random_branching_site_probability - return outcome and self.valid_branching_nodes + def is_valid_nonmerge_site(self, node): + # Reject if high-degree + if self.degree(node) > 3: + return False + + # Reject if node belongs to ignored fragment + if self.node_component_id[node] in self.ignore_fragments: + return False + + # Reject if branching and near another branching node + if self._has_nearby_branching(node): + return False + + # Reject if near merge site + dd, _ = self.merge_sites_kdtree.query(self.node_xyz[node]) + if dd < 100: + return False - def _has_nearby_branching(self, root, max_depth=60): + # Site is valid + return True + + def _has_nearby_branching(self, root, max_depth=100): queue = [(root, 0)] visited = {root} while queue: # Visit node i, d_i = queue.pop() - if self.graph.degree[i] > 2 and d_i > 0: + if self.degree[i] > 2 and d_i > 0: return True # Update queue - for j in self.graph.neighbors(i): - d_j = d_i + self.graph.dist(i, j) + for j in self.neighbors(i): + d_j = d_i + self.dist(i, j) if j not in visited and d_j < max_depth: queue.append((j, d_j)) visited.add(j) @@ -249,6 +251,9 @@ def _list_indices(self): neg_idxs = np.random.choice(n_neg, size=size, replace=False) return np.concatenate((-neg_idxs, np.arange(n_pos))) + def __getattr__(self, name): + return getattr(self.graph, name) + def __len__(self): """ Counts the number of examples in the dataset. @@ -304,8 +309,50 @@ def _build_index_table(self): table.append((b_idx, int(local_idx))) return table + def save_val_summary(self, output_dir): + """ + Saves a summary of the validation dataset to a CSV file. + + Parameters + ---------- + output_dir : str + Directory to save the CSV file in. + """ + rows = [] + for brain_dataset in self.datasets: + brain_id = brain_dataset.brain_id + + # Merge sites + for _, row in brain_dataset.merge_sites.iterrows(): + rows.append({ + "brain_id": brain_id, + "swc_name": row["filename"], + "xyz": row["xyz"], + "label": "merge", + }) + + # Nonmerge sites + for _, row in brain_dataset.nonmerge_sites.iterrows(): + rows.append({ + "brain_id": brain_id, + "swc_name": row["filename"], + "xyz": row["xyz"], + "label": "nonmerge", + }) + + df = pd.DataFrame(rows) + df.to_csv(os.path.join(output_dir, "val_summary.csv"), index=False) + # --- Dataset Interface --- def __len__(self): + """ + Gets the number of examples in the dataset. + + Returns + ------- + int + Number of examples in the dataset. + """ return len(self._index_table) def __getitem__(self, idx): @@ -355,8 +402,8 @@ def __init__( is_multimodal=False, modality=None, sampler=None, - use_shuffle=True, - prefetch=8, + shuffle=True, + prefetch=32, ): # Check that modality is valid if modality not in self._VALID_MODALITIES: @@ -371,7 +418,7 @@ def __init__( # Instance attributes self.is_multimodal = is_multimodal self.modality = modality - self.use_shuffle = use_shuffle + self.shuffle = shuffle self.prefetch = prefetch self.img_shape = (2,) + dataset.datasets[0].patch_loader.patch_shape @@ -387,7 +434,7 @@ def __iter__(self): # Extract indices self.dataset._index_table = self.dataset._build_index_table() idxs = self.dataset.get_idxs() - if self.use_shuffle: + if self.shuffle: np.random.shuffle(idxs) # Split into batches upfront @@ -584,9 +631,10 @@ def create_dataset_collection( img_path = os.path.join(img_prefixes[brain_id], "0") segmentation_id = get_segmentation_id(sites_root_path, brain_id) sites_path = os.path.join(sites_root_path, brain_id, segmentation_id) - swcs_path = util.get_google_swcs_prefix( - swcs_root_path, brain_id, segmentation_id - ) + swcs_path = os.path.join(swcs_root_path, brain_id, segmentation_id, "fragments") + #util.get_google_swcs_prefix( + # swcs_root_path, brain_id, segmentation_id + #) # Add dataset print(f" \nBrain ID [{i}/{len(brain_ids)}]: {brain_id}") @@ -606,7 +654,7 @@ def create_dataset_collection( # Check whether to generate examples for validation if dataset_mode == "Val": - num_target_neg = 10 * len(dataset.merge_sites) + num_target_neg = 5 * len(dataset.merge_sites) num_added_neg = num_target_neg - len(dataset.nonmerge_sites) dataset.add_nonmerge_sites(num_added_neg) From 7d72edf7e8b0b0cb23e8f886a6b740b3b44134d5 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sun, 14 Jun 2026 23:14:57 +0000 Subject: [PATCH 11/11] refactor: miscellaneous --- .../merge_proofreading/merge_datamodules.py | 11 +++++---- src/neuron_proofreader/utils/swc_util.py | 24 +++++++++---------- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py index d9b336a1..4687409c 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py @@ -63,7 +63,7 @@ class BrainDataset: """ giant_component_cable_length = 30000 - random_branching_site_probability = 0.5 + random_branching_site_probability = 0.25 def __init__( self, @@ -171,13 +171,13 @@ def get_random_nonmerge_site(self): # Sample node node = util.sample_once(nodes) if self.is_valid_nonmerge_site(node): - return node + return node, 0 # Try again n_attempts += 1 if n_attempts > 100: print(f"Failed to find valid random nonmerge site for {self.brain_id}!") - return util.sample_once(self.nodes) + return util.sample_once(self.nodes), 0 # --- Helpers --- def add_nonmerge_sites(self, num_sites): @@ -228,7 +228,7 @@ def _has_nearby_branching(self, root, max_depth=100): while queue: # Visit node i, d_i = queue.pop() - if self.degree[i] > 2 and d_i > 0: + if self.degree[i] >= 3 and i != root: return True # Update queue @@ -608,6 +608,7 @@ def create_dataset_collection( graph_config=None, img_config=None, subgraph_depth=100, + val_neg_multiplier=5, ): # Set parameters based on mode print(f"\nLoading {dataset_mode} Dataset...") @@ -654,7 +655,7 @@ def create_dataset_collection( # Check whether to generate examples for validation if dataset_mode == "Val": - num_target_neg = 5 * len(dataset.merge_sites) + num_target_neg = val_neg_multiplier * len(dataset.merge_sites) num_added_neg = num_target_neg - len(dataset.nonmerge_sites) dataset.add_nonmerge_sites(num_added_neg) diff --git a/src/neuron_proofreader/utils/swc_util.py b/src/neuron_proofreader/utils/swc_util.py index 00e828fa..3be81da2 100644 --- a/src/neuron_proofreader/utils/swc_util.py +++ b/src/neuron_proofreader/utils/swc_util.py @@ -159,7 +159,7 @@ def read_swcs(self, swc_paths): Parameters ---------- swc_paths : List[str] - List of paths to SWC files to be read. + Paths to SWC files to be read. Returns ------- @@ -448,14 +448,13 @@ def parse(self, content, filename): def process_content(self, content): """ - Processes lines of text from an SWC file, extracting an offset - value and returning the remaining content starting from the line - immediately after the last commented line. + Extracts an offset and returns the remaining content starting from the + line after the last commented line. Parameters ---------- content : List[str] - List of strings such that each is a line from an SWC file. + Lines from an SWC file. Returns ------- @@ -475,7 +474,7 @@ def process_content(self, content): def read_coordinate(self, xyz_str, offset=(0, 0, 0)): """ - Reads a coordinate from a string and converts it to voxel coordinates. + Reads coordinate from a string and converts it to voxel coordinates. Parameters ---------- @@ -498,7 +497,7 @@ def write_points( zip_path, points, color=None, prefix="", radius=10, write_mode="w" ): """ - Writes a list of 3D points to individual SWC files in the specified + Writes list of 3D points to individual SWC files in the specified directory. Parameters @@ -524,15 +523,14 @@ def write_points( def to_zipped_point(zf, filename, xyz, color=None, radius=5): """ - Writes a point to an SWC file format, which is then stored in a ZIP - archive. + Writes point to an SWC file in a ZIP archive. Parameters ---------- zf : zipfile.ZipFile ZipFile used to write the generated SWC file. filename : str - Filename of SWC file. + SWC filename. xyz : ArrayLike Point to be written to SWC file. color : str, optional @@ -557,7 +555,7 @@ def to_zipped_point(zf, filename, xyz, color=None, radius=5): # --- Helpers --- def get_segment_id(swc_name): """ - Extract the segment ID from an SWC filename. + Extracts the segment ID from an SWC filename. Parameters ---------- @@ -578,7 +576,7 @@ def get_segment_id(swc_name): def get_swc_name(path): """ - Gets name of the SWC file at the given path, minus the extension. + Gets SWC filename at the given path, minus the extension. Parameters ---------- @@ -588,7 +586,7 @@ def get_swc_name(path): Returns ------- name : str - Name of the SWC file, minus the extension. + SWC filename minus the extension. """ return os.path.splitext(os.path.basename(path))[0]