Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/neuron_proofreader/machine_learning/image_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def create_mask(self, center, shape, node):
subgraph = self.graph.rooted_subgraph(node, depth)

# Annotate mask
mask = np.zeros(shape)
mask = np.zeros(shape, dtype=np.float32)
#self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP
self.annotate_fragment(mask, subgraph, offset, fill_val=1)
return mask
Expand All @@ -199,7 +199,7 @@ def read_image(self, center, shape):
patch = self.img.read(center, shape)
patch = np.minimum(patch, self.brightness_clip)
patch = img_util.normalize(patch, percentiles=self.percentiles)
return patch
return patch.astype(np.float32)

# --- Helpers ---
def __getattr__(self, name):
Expand Down
89 changes: 43 additions & 46 deletions src/neuron_proofreader/merge_proofreading/search_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@

from abc import ABC, abstractmethod
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from queue import Queue
from threading import Thread
from torch.utils.data import IterableDataset

import itertools
import networkx as nx
import numpy as np
import torch
Expand Down Expand Up @@ -59,27 +62,38 @@ def __init__(

# --- Core routines ---
def __iter__(self):
sites_generator = self._all_sites()
with ThreadPoolExecutor(max_workers=self.prefetch) as executor:
# Submit initial threads
pending = {}
for _ in range(self.prefetch):
self._submit(executor, sites_generator, pending)

# Process remaining jobs until complete
while pending:
done, _ = wait(pending.keys(), return_when=FIRST_COMPLETED)
for thread in done:
site = pending.pop(thread)
yield from self.get_input(site, *thread.result())
self._submit(executor, sites_generator, pending)

def _submit(self, executor, sites_generator, pending):
try:
site = next(sites_generator)
pending[executor.submit(self.patch_loader, site)] = site
except StopIteration:
pass
patch_queue = Queue(maxsize=self.prefetch)
sentinel = object()

def producer():
sites = self._all_sites()
with ThreadPoolExecutor(max_workers=self.prefetch) as executor:
futures = {}

def fill():
while len(futures) < self.prefetch:
try:
site = next(sites)
futures[executor.submit(self.patch_loader, site)] = site
except StopIteration:
break

fill()
while futures:
done, _ = wait(futures, return_when=FIRST_COMPLETED)
for f in done:
patch_queue.put((futures.pop(f), f.result()))
fill()

patch_queue.put(sentinel)

Thread(target=producer, daemon=True).start()

while True:
item = patch_queue.get()
if item is sentinel:
break
yield from self.get_input(*item)

def _all_sites(self):
visited_ids = set()
Expand Down Expand Up @@ -231,29 +245,7 @@ def get_patch(self, nodes, img, offset):
yield node, img[(slice(0, 2), *s)]

def generate_patch_and_pc(self, nodes, img, offset):
# Initializations
label_mask = self.get_label_mask(nodes, img.shape, offset)
patch_centers = self.get_patch_centers(nodes) - offset

# Populate batch array
patches = np.empty((len(nodes), 2,) + self.patch_shape)
point_clouds = np.empty((len(nodes), 3, 3600), dtype=np.float32)
for i, (node, center) in enumerate(zip(nodes, patch_centers)):
s = img_util.get_slices(center, self.patch_shape)
patches[i, 0, ...] = img_util.normalize(img[s])
patches[i, 1, ...] = label_mask[s]

subgraph = self.graph.rooted_subgraph(node, self.subgraph_radius)
point_clouds[i] = subgraph_to_point_cloud(subgraph)

# Build batch dictionary
batch = ml_util.TensorDict(
{
"img": ml_util.to_tensor(patches),
"point_cloud": ml_util.to_tensor(point_clouds),
}
)
return nodes, batch
pass

# --- Helpers ---
def estimate_iterations(self):
Expand All @@ -276,7 +268,7 @@ def estimate_iterations(self):

# Report results
print("# Fragments:", n_fragments)
print(f"Total Cable Length: {total_cable_length / 1000:.2f}mm")
print(f"Total Cable Length: {total_cable_length / 10**5:.2f}cm")
return int(total_cable_length / self.step_size)


Expand Down Expand Up @@ -310,9 +302,14 @@ def __init__(
self.patch_loader = DetectionPatchLoader(self.graph, img_config)
self.search_mode = "branching_nodes"

def estimate_iterations(self):
return len(self.branching_nodes())

def generate_component_sites(self, root):
visited = set()
for i, j in nx.dfs_edges(self.graph, source=root):
if self.degree[i] >= 3:
if self.degree[i] >= 3 and i not in visited:
visited.add(i)
yield i

def get_patch(self, node, img):
Expand Down
Loading