Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
360 changes: 288 additions & 72 deletions src/neuron_proofreader/geometric_learning/curve_transformer.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions src/neuron_proofreader/machine_learning/image_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
)
from neuron_proofreader.utils import geometry_util, img_util, util


# ----------------------------------------------------------------------------
# Image Class
# ----------------------------------------------------------------------------


class TensorStoreImage:
"""
Class that reads images with the TensorStore library.
Expand Down Expand Up @@ -100,6 +100,7 @@ def shape(self):
# PatchLoader Class
# ----------------------------------------------------------------------------


class PatchLoader(ABC):
"""
A class for reading image patches and generating segment masks.
Expand Down Expand Up @@ -176,7 +177,7 @@ def create_mask(self, center, shape, node):

# Annotate mask
mask = np.zeros(shape, dtype=np.float32)
#self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP
# self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP
self.annotate_fragment(mask, subgraph, offset, fill_val=1)
return mask

Expand Down Expand Up @@ -231,6 +232,7 @@ def stack(img, mask):
# PatchLoader Subclasses
# ----------------------------------------------------------------------------


class DetectionPatchLoader(PatchLoader):

# --- Implementation of Abstract Inferface ---
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def __iter__(self):

# Split into batches upfront
batch_idx_groups = [
idxs[start : min(start + self.batch_size, len(idxs))]
idxs[start: min(start + self.batch_size, len(idxs))]
for start in range(0, len(idxs), self.batch_size)
]

Expand Down
4 changes: 3 additions & 1 deletion src/neuron_proofreader/merge_proofreading/merge_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ def save_results(
):
self.save_sites(output_dir)
if save_fragments:
self.dataset.graph.node_radius = 10 * np.maximum(self.node_preds, 0.1)
self.dataset.graph.node_radius = 10 * np.maximum(
self.node_preds, 0.1
)
fragments_path = os.path.join(output_dir, "fragments.zip")
self.dataset.to_zipped_swcs(fragments_path, use_radius=True)

Expand Down
20 changes: 9 additions & 11 deletions src/neuron_proofreader/merge_proofreading/search_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,10 @@
from threading import Thread
from torch.utils.data import IterableDataset

import itertools
import networkx as nx
import numpy as np
import torch

from neuron_proofreader.machine_learning.point_cloud_models import (
subgraph_to_point_cloud,
)
from neuron_proofreader.machine_learning.image_dataloader import (
DetectionBatchLoader,
DetectionPatchLoader,
Expand Down Expand Up @@ -64,31 +60,33 @@ def __init__(
def __iter__(self):
patch_queue = Queue(maxsize=self.prefetch)
sentinel = object()

def producer():
sites = self._all_sites()
with ThreadPoolExecutor(max_workers=self.prefetch) as executor:
futures = {}

def fill():
while len(futures) < self.prefetch:
try:
site = next(sites)
futures[executor.submit(self.patch_loader, site)] = site
futures[
executor.submit(self.patch_loader, site)
] = site
except StopIteration:
break

fill()
while futures:
done, _ = wait(futures, return_when=FIRST_COMPLETED)
for f in done:
patch_queue.put((futures.pop(f), f.result()))
fill()

patch_queue.put(sentinel)

Thread(target=producer, daemon=True).start()

while True:
item = patch_queue.get()
if item is sentinel:
Expand Down
6 changes: 4 additions & 2 deletions src/neuron_proofreader/proofreading_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ def split_proofreading(

def connect_soma_fragments(self, max_dist=25):
self.step_cnt += 1
self.log(f"\nStep {self.step_cnt}: Connect Soma Fragments with dist={max_dist}")
self.log(
f"\nStep {self.step_cnt}: Connect Soma Fragments with dist={max_dist}"
)
summary = self.graph.connect_soma_fragments(max_dist=max_dist)
self.log(summary)

Expand All @@ -173,7 +175,7 @@ def merge_proofreading(self, mode):
swc_util.write_points(
zip_path, merge_sites, color=color, prefix="merge_site", radius=10
)

# --- Helpers ---
def log(self, txt):
"""
Expand Down
1 change: 1 addition & 0 deletions src/neuron_proofreader/skeleton_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ def component_to_zipped_swc(self, zipfile, root, use_radius=False):
Indication of whether to preserve radii of nodes or use default
radius of 2μm. Default is False.
"""

# Subroutines
def write_entry(node, parent):
"""
Expand Down
4 changes: 3 additions & 1 deletion src/neuron_proofreader/utils/swc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,9 @@ def read_zips(self, zip_paths, read_fn):
pbar = self.manual_progress_bar(len(zip_paths))
with ProcessPoolExecutor() as executor:
# Assign processes
futures = {executor.submit(read_fn, path) for path in zip_paths[0:500]} # TEMP
futures = {
executor.submit(read_fn, path) for path in zip_paths[0:500]
} # TEMP

# Store results
swc_dicts = deque()
Expand Down
2 changes: 1 addition & 1 deletion src/neuron_proofreader/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def parse_cloud_path(path):
"""
# Remove s3:// or gs:// if present
if path.startswith("s3://") or path.startswith("gs://"):
path = path[len("s3://") :]
path = path[len("s3://"):]

# Split path
parts = path.split("/", 1)
Expand Down
Loading