Skip to content
11 changes: 6 additions & 5 deletions src/neuron_proofreader/merge_proofreading/merge_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class BrainDataset:
"""

giant_component_cable_length = 30000
random_branching_site_probability = 0.5
random_branching_site_probability = 0.25

def __init__(
self,
Expand Down Expand Up @@ -171,13 +171,13 @@ def get_random_nonmerge_site(self):
# Sample node
node = util.sample_once(nodes)
if self.is_valid_nonmerge_site(node):
return node
return node, 0

# Try again
n_attempts += 1
if n_attempts > 100:
print(f"Failed to find valid random nonmerge site for {self.brain_id}!")
return util.sample_once(self.nodes)
return util.sample_once(self.nodes), 0

# --- Helpers ---
def add_nonmerge_sites(self, num_sites):
Expand Down Expand Up @@ -228,7 +228,7 @@ def _has_nearby_branching(self, root, max_depth=100):
while queue:
# Visit node
i, d_i = queue.pop()
if self.degree[i] > 2 and d_i > 0:
if self.degree[i] >= 3 and i != root:
return True

# Update queue
Expand Down Expand Up @@ -608,6 +608,7 @@ def create_dataset_collection(
graph_config=None,
img_config=None,
subgraph_depth=100,
val_neg_multiplier=5,
):
# Set parameters based on mode
print(f"\nLoading {dataset_mode} Dataset...")
Expand Down Expand Up @@ -654,7 +655,7 @@ def create_dataset_collection(

# Check whether to generate examples for validation
if dataset_mode == "Val":
num_target_neg = 5 * len(dataset.merge_sites)
num_target_neg = val_neg_multiplier * len(dataset.merge_sites)
num_added_neg = num_target_neg - len(dataset.nonmerge_sites)
dataset.add_nonmerge_sites(num_added_neg)

Expand Down
24 changes: 11 additions & 13 deletions src/neuron_proofreader/utils/swc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def read_swcs(self, swc_paths):
Parameters
----------
swc_paths : List[str]
List of paths to SWC files to be read.
Paths to SWC files to be read.

Returns
-------
Expand Down Expand Up @@ -448,14 +448,13 @@ def parse(self, content, filename):

def process_content(self, content):
"""
Processes lines of text from an SWC file, extracting an offset
value and returning the remaining content starting from the line
immediately after the last commented line.
Extracts an offset and returns the remaining content starting from the
line after the last commented line.

Parameters
----------
content : List[str]
List of strings such that each is a line from an SWC file.
Lines from an SWC file.

Returns
-------
Expand All @@ -475,7 +474,7 @@ def process_content(self, content):

def read_coordinate(self, xyz_str, offset=(0, 0, 0)):
"""
Reads a coordinate from a string and converts it to voxel coordinates.
Reads coordinate from a string and converts it to voxel coordinates.

Parameters
----------
Expand All @@ -498,7 +497,7 @@ def write_points(
zip_path, points, color=None, prefix="", radius=10, write_mode="w"
):
"""
Writes a list of 3D points to individual SWC files in the specified
Writes list of 3D points to individual SWC files in the specified
directory.

Parameters
Expand All @@ -524,15 +523,14 @@ def write_points(

def to_zipped_point(zf, filename, xyz, color=None, radius=5):
"""
Writes a point to an SWC file format, which is then stored in a ZIP
archive.
Writes point to an SWC file in a ZIP archive.

Parameters
----------
zf : zipfile.ZipFile
ZipFile used to write the generated SWC file.
filename : str
Filename of SWC file.
SWC filename.
xyz : ArrayLike
Point to be written to SWC file.
color : str, optional
Expand All @@ -557,7 +555,7 @@ def to_zipped_point(zf, filename, xyz, color=None, radius=5):
# --- Helpers ---
def get_segment_id(swc_name):
"""
Extract the segment ID from an SWC filename.
Extracts the segment ID from an SWC filename.

Parameters
----------
Expand All @@ -578,7 +576,7 @@ def get_segment_id(swc_name):

def get_swc_name(path):
"""
Gets name of the SWC file at the given path, minus the extension.
Gets SWC filename at the given path, minus the extension.

Parameters
----------
Expand All @@ -588,7 +586,7 @@ def get_swc_name(path):
Returns
-------
name : str
Name of the SWC file, minus the extension.
SWC filename minus the extension.
"""
return os.path.splitext(os.path.basename(path))[0]

Expand Down
Loading