diff --git a/src/neuron_proofreader/machine_learning/image_dataloader.py b/src/neuron_proofreader/machine_learning/image_dataloader.py
index 790969c3..b1fd6a5c 100644
--- a/src/neuron_proofreader/machine_learning/image_dataloader.py
+++ b/src/neuron_proofreader/machine_learning/image_dataloader.py
@@ -175,7 +175,7 @@ def create_mask(self, center, shape, node):
         subgraph = self.graph.rooted_subgraph(node, depth)
 
         # Annotate mask
-        mask = np.zeros(shape)
+        mask = np.zeros(shape, dtype=np.float32)
         #self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP
         self.annotate_fragment(mask, subgraph, offset, fill_val=1)
         return mask
@@ -199,7 +199,7 @@ def read_image(self, center, shape):
         patch = self.img.read(center, shape)
         patch = np.minimum(patch, self.brightness_clip)
         patch = img_util.normalize(patch, percentiles=self.percentiles)
-        return patch
+        return patch.astype(np.float32)
 
     # --- Helpers ---
     def __getattr__(self, name):
diff --git a/src/neuron_proofreader/merge_proofreading/search_datasets.py b/src/neuron_proofreader/merge_proofreading/search_datasets.py
index 473685d1..52833cbf 100644
--- a/src/neuron_proofreader/merge_proofreading/search_datasets.py
+++ b/src/neuron_proofreader/merge_proofreading/search_datasets.py
@@ -11,8 +11,11 @@
 
 from abc import ABC, abstractmethod
 from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
+from queue import Queue
+from threading import Thread
 from torch.utils.data import IterableDataset
 
+import itertools
 import networkx as nx
 import numpy as np
 import torch
@@ -59,27 +62,55 @@ def __init__(
 
     # --- Core routines ---
     def __iter__(self):
-        sites_generator = self._all_sites()
-        with ThreadPoolExecutor(max_workers=self.prefetch) as executor:
-            # Submit initial threads
-            pending = {}
-            for _ in range(self.prefetch):
-                self._submit(executor, sites_generator, pending)
-
-            # Process remaining jobs until complete
-            while pending:
-                done, _ = wait(pending.keys(), return_when=FIRST_COMPLETED)
-                for thread in done:
-                    site = pending.pop(thread)
-                    yield from self.get_input(site, *thread.result())
-                    self._submit(executor, sites_generator, pending)
-
-    def _submit(self, executor, sites_generator, pending):
-        try:
-            site = next(sites_generator)
-            pending[executor.submit(self.patch_loader, site)] = site
-        except StopIteration:
-            pass
+        """
+        Iterates over the search sites and yields the inputs generated by
+        "get_input". Patches are loaded in a background thread and passed
+        to the consumer through a bounded queue.
+        """
+        patch_queue = Queue(maxsize=self.prefetch)
+        sentinel = object()
+
+        def load_patches():
+            sites = self._all_sites()
+            with ThreadPoolExecutor(max_workers=self.prefetch) as executor:
+                futures = {}
+
+                def fill():
+                    while len(futures) < self.prefetch:
+                        try:
+                            site = next(sites)
+                            futures[executor.submit(self.patch_loader, site)] = site
+                        except StopIteration:
+                            break
+
+                fill()
+                while futures:
+                    done, _ = wait(futures, return_when=FIRST_COMPLETED)
+                    for f in done:
+                        # Unpack the loader result so "get_input" is called
+                        # with the same arguments as before this change
+                        patch_queue.put((futures.pop(f), *f.result()))
+                    fill()
+
+        def producer():
+            try:
+                load_patches()
+            except Exception as err:
+                # Forward the failure; otherwise the consumer would block
+                # forever on "patch_queue.get()"
+                patch_queue.put(err)
+            finally:
+                patch_queue.put(sentinel)
+
+        Thread(target=producer, daemon=True).start()
+
+        while True:
+            item = patch_queue.get()
+            if item is sentinel:
+                break
+            if isinstance(item, Exception):
+                raise item
+            yield from self.get_input(*item)
 
     def _all_sites(self):
         visited_ids = set()
@@ -231,29 +262,7 @@ def get_patch(self, nodes, img, offset):
             yield node, img[(slice(0, 2), *s)]
 
     def generate_patch_and_pc(self, nodes, img, offset):
-        # Initializations
-        label_mask = self.get_label_mask(nodes, img.shape, offset)
-        patch_centers = self.get_patch_centers(nodes) - offset
-
-        # Populate batch array
-        patches = np.empty((len(nodes), 2,) + self.patch_shape)
-        point_clouds = np.empty((len(nodes), 3, 3600), dtype=np.float32)
-        for i, (node, center) in enumerate(zip(nodes, patch_centers)):
-            s = img_util.get_slices(center, self.patch_shape)
-            patches[i, 0, ...] = img_util.normalize(img[s])
-            patches[i, 1, ...] = label_mask[s]
-
-            subgraph = self.graph.rooted_subgraph(node, self.subgraph_radius)
-            point_clouds[i] = subgraph_to_point_cloud(subgraph)
-
-        # Build batch dictionary
-        batch = ml_util.TensorDict(
-            {
-                "img": ml_util.to_tensor(patches),
-                "point_cloud": ml_util.to_tensor(point_clouds),
-            }
-        )
-        return nodes, batch
+        pass
 
     # --- Helpers ---
     def estimate_iterations(self):
@@ -276,7 +285,7 @@ def estimate_iterations(self):
 
         # Report results
         print("# Fragments:", n_fragments)
-        print(f"Total Cable Length: {total_cable_length / 1000:.2f}mm")
+        print(f"Total Cable Length: {total_cable_length / 10**4:.2f}cm")
         return int(total_cable_length / self.step_size)
 
 
@@ -310,9 +319,14 @@ def __init__(
         self.patch_loader = DetectionPatchLoader(self.graph, img_config)
         self.search_mode = "branching_nodes"
 
+    def estimate_iterations(self):
+        return len(self.branching_nodes())
+
     def generate_component_sites(self, root):
+        visited = set()
         for i, j in nx.dfs_edges(self.graph, source=root):
-            if self.degree[i] >= 3:
+            if self.degree[i] >= 3 and i not in visited:
+                visited.add(i)
                 yield i
 
     def get_patch(self, node, img):