diff --git a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py index d9b336a..4687409 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py @@ -63,7 +63,7 @@ class BrainDataset: """ giant_component_cable_length = 30000 - random_branching_site_probability = 0.5 + random_branching_site_probability = 0.25 def __init__( self, @@ -171,13 +171,13 @@ def get_random_nonmerge_site(self): # Sample node node = util.sample_once(nodes) if self.is_valid_nonmerge_site(node): - return node + return node, 0 # Try again n_attempts += 1 if n_attempts > 100: print(f"Failed to find valid random nonmerge site for {self.brain_id}!") - return util.sample_once(self.nodes) + return util.sample_once(self.nodes), 0 # --- Helpers --- def add_nonmerge_sites(self, num_sites): @@ -228,7 +228,7 @@ def _has_nearby_branching(self, root, max_depth=100): while queue: # Visit node i, d_i = queue.pop() - if self.degree[i] > 2 and d_i > 0: + if self.degree[i] >= 3 and i != root: return True # Update queue @@ -608,6 +608,7 @@ def create_dataset_collection( graph_config=None, img_config=None, subgraph_depth=100, + val_neg_multiplier=5, ): # Set parameters based on mode print(f"\nLoading {dataset_mode} Dataset...") @@ -654,7 +655,7 @@ def create_dataset_collection( # Check whether to generate examples for validation if dataset_mode == "Val": - num_target_neg = 5 * len(dataset.merge_sites) + num_target_neg = val_neg_multiplier * len(dataset.merge_sites) num_added_neg = num_target_neg - len(dataset.nonmerge_sites) dataset.add_nonmerge_sites(num_added_neg) diff --git a/src/neuron_proofreader/utils/swc_util.py b/src/neuron_proofreader/utils/swc_util.py index 00e828f..3be81da 100644 --- a/src/neuron_proofreader/utils/swc_util.py +++ b/src/neuron_proofreader/utils/swc_util.py @@ -159,7 +159,7 @@ def read_swcs(self, swc_paths): Parameters ---------- swc_paths : List[str] - List of paths to SWC files to be read. + Paths to SWC files to be read. Returns ------- @@ -448,14 +448,13 @@ def parse(self, content, filename): def process_content(self, content): """ - Processes lines of text from an SWC file, extracting an offset - value and returning the remaining content starting from the line - immediately after the last commented line. + Extracts an offset and returns the remaining content starting from the + line after the last commented line. Parameters ---------- content : List[str] - List of strings such that each is a line from an SWC file. + Lines from an SWC file. Returns ------- @@ -475,7 +474,7 @@ def process_content(self, content): def read_coordinate(self, xyz_str, offset=(0, 0, 0)): """ - Reads a coordinate from a string and converts it to voxel coordinates. + Reads coordinate from a string and converts it to voxel coordinates. Parameters ---------- @@ -498,7 +497,7 @@ def write_points( zip_path, points, color=None, prefix="", radius=10, write_mode="w" ): """ - Writes a list of 3D points to individual SWC files in the specified + Writes list of 3D points to individual SWC files in the specified directory. Parameters @@ -524,15 +523,14 @@ def write_points( def to_zipped_point(zf, filename, xyz, color=None, radius=5): """ - Writes a point to an SWC file format, which is then stored in a ZIP - archive. + Writes point to an SWC file in a ZIP archive. Parameters ---------- zf : zipfile.ZipFile ZipFile used to write the generated SWC file. filename : str - Filename of SWC file. + SWC filename. xyz : ArrayLike Point to be written to SWC file. color : str, optional @@ -557,7 +555,7 @@ def to_zipped_point(zf, filename, xyz, color=None, radius=5): # --- Helpers --- def get_segment_id(swc_name): """ - Extract the segment ID from an SWC filename. + Extracts the segment ID from an SWC filename. Parameters ---------- @@ -578,7 +576,7 @@ def get_segment_id(swc_name): def get_swc_name(path): """ - Gets name of the SWC file at the given path, minus the extension. + Gets SWC filename at the given path, minus the extension. Parameters ---------- @@ -588,7 +586,7 @@ def get_swc_name(path): Returns ------- name : str - Name of the SWC file, minus the extension. + SWC filename minus the extension. """ return os.path.splitext(os.path.basename(path))[0]