From 54053313c0a012f4604ebd966a41c189dc061031 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Wed, 17 Jun 2026 23:09:55 +0000 Subject: [PATCH 1/3] feat: new decoders --- .../geometric_learning/curve_transformer.py | 360 ++++++++++++++---- .../machine_learning/image_dataloader.py | 6 +- .../merge_proofreading/merge_datamodules.py | 2 +- .../merge_proofreading/merge_detection.py | 4 +- .../merge_proofreading/search_datasets.py | 20 +- .../proofreading_pipeline.py | 6 +- src/neuron_proofreader/skeleton_graph.py | 1 + src/neuron_proofreader/utils/swc_util.py | 4 +- src/neuron_proofreader/utils/util.py | 2 +- 9 files changed, 314 insertions(+), 91 deletions(-) diff --git a/src/neuron_proofreader/geometric_learning/curve_transformer.py b/src/neuron_proofreader/geometric_learning/curve_transformer.py index 45cc0daf..533558dc 100644 --- a/src/neuron_proofreader/geometric_learning/curve_transformer.py +++ b/src/neuron_proofreader/geometric_learning/curve_transformer.py @@ -149,113 +149,293 @@ def forward(self, offsets, mask=None): return z, tokens -class CurveDecoder(nn.Module): +class INRDecoder(nn.Module): """ - Transformer decoder that reconstructs a 3D space curve from a latent - vector and the encoder's token representations. Positional queries are - sinusoidally encoded over arc position and biased by the global latent. - The output resolution can differ from the encoder input, allowing - decoding at arbitrary granularity. + Implicit neural representation decoder. Models the curve as F(z, t) -> + offsets where t in [0, 1] is the normalized arc position of each segment. + Supports three MLP variants: plain, residual (ResNet-style), and SIREN. """ def __init__( self, - n_points=100, segment_len=10, - d_token=64, - n_heads=4, - n_layers=4, - d_ff=64, latent_dim=32, + d_hidden=256, + n_layers=4, + n_frequencies=16, + mlp_type="residual", dropout=0.1, ): """ Parameters ---------- - n_points : int - Default number of output curve points. segment_len : int - Number of points per segment token (must match encoder). - d_token : int - Dimension of each token throughout the transformer. - n_heads : int - Number of attention heads. - n_layers : int - Number of transformer decoder layers. - d_ff : int - Feed-forward hidden dimension. + Number of points per segment. latent_dim : int Dimension of the input latent vector. + d_hidden : int + Hidden dimension of the MLP. + n_layers : int + Number of hidden layers. + n_frequencies : int + Number of Fourier frequency bands for positional encoding of t. + Not used when mlp_type is 'siren' (SIREN handles PE internally). + mlp_type : str + One of 'plain', 'residual', or 'siren'. dropout : float - Dropout probability. + Dropout probability. Not used in SIREN. """ super().__init__() self.segment_len = segment_len - self.n_segments = n_points // segment_len + self.n_frequencies = n_frequencies + self.mlp_type = mlp_type + + d_out = segment_len * 3 + + if mlp_type == "siren": + # SIREN handles position encoding internally via sin activations — + # just concatenate raw t (scalar) with z + d_input = latent_dim + 1 + layers = [SirenLayer(d_input, d_hidden, is_first=True)] + for _ in range(n_layers - 1): + layers.append(SirenLayer(d_hidden, d_hidden)) + layers.append(nn.Linear(d_hidden, d_out)) + self.mlp = nn.Sequential(*layers) + + elif mlp_type == "residual": + d_input = latent_dim + 2 * n_frequencies + self.input_proj = nn.Sequential( + nn.Linear(d_input, d_hidden), + nn.LayerNorm(d_hidden), + nn.GELU(), + ) + self.blocks = nn.ModuleList( + [ResidualBlock(d_hidden, dropout) for _ in range(n_layers)] + ) + self.output_proj = nn.Linear(d_hidden, d_out) + + elif mlp_type == "plain": + d_input = latent_dim + 2 * n_frequencies + layers = [ + nn.Linear(d_input, d_hidden), + nn.GELU(), + nn.Dropout(dropout), + ] + for _ in range(n_layers - 1): + layers += [ + nn.Linear(d_hidden, d_hidden), + nn.GELU(), + nn.Dropout(dropout), + ] + layers.append(nn.Linear(d_hidden, d_out)) + self.mlp = nn.Sequential(*layers) - # Project latent to d_token to bias the positional queries - self.latent_proj = nn.Linear(latent_dim, d_token) + else: + raise ValueError(f"Unknown mlp_type: {mlp_type}") - decoder_layer = nn.TransformerDecoderLayer( - d_model=d_token, - nhead=n_heads, - dim_feedforward=d_ff, - dropout=dropout, + def positional_encoding(self, t): + """ + Fourier positional encoding for scalar arc positions. + + Parameters + ---------- + t : torch.Tensor + Arc positions of shape (n_segments,) in [0, 1]. + + Returns + ------- + torch.Tensor + Shape (n_segments, 2 * n_frequencies). + """ + freqs = 2 ** torch.arange(self.n_frequencies, device=t.device).float() + x = t.unsqueeze(-1) * freqs * torch.pi + return torch.cat([torch.sin(x), torch.cos(x)], dim=-1) + + def forward(self, z, n_segments=None): + """ + Parameters + ---------- + z : torch.Tensor + Latent vector of shape (B, latent_dim). + n_segments : int, optional + Number of output segments. + + Returns + ------- + torch.Tensor + Reconstructed offset sequence of shape (B, n_segments * segment_len, 3). + """ + B = z.shape[0] + n_segments = n_segments or 10 + t = torch.linspace(0, 1, n_segments, device=z.device) + + if self.mlp_type == "siren": + # Raw scalar t concatenated with z — SIREN encodes frequency internally + t_exp = ( + t.unsqueeze(-1).unsqueeze(0).expand(B, -1, -1) + ) # (B, n_seg, 1) + z_exp = z.unsqueeze(1).expand( + -1, n_segments, -1 + ) # (B, n_seg, latent_dim) + x = torch.cat([z_exp, t_exp], dim=-1) # (B, n_seg, latent_dim+1) + out = self.mlp(x) + + elif self.mlp_type == "residual": + pe = self.positional_encoding(t) + z_exp = z.unsqueeze(1).expand(-1, n_segments, -1) + pe_exp = pe.unsqueeze(0).expand(B, -1, -1) + x = torch.cat([z_exp, pe_exp], dim=-1) + x = self.input_proj(x) + for block in self.blocks: + x = block(x) + out = self.output_proj(x) + + else: # plain + pe = self.positional_encoding(t) + z_exp = z.unsqueeze(1).expand(-1, n_segments, -1) + pe_exp = pe.unsqueeze(0).expand(B, -1, -1) + x = torch.cat([z_exp, pe_exp], dim=-1) + out = self.mlp(x) + + return out.reshape(B, n_segments * self.segment_len, 3) + + +class AutoregressiveINRDecoder(nn.Module): + """ + Autoregressive implicit neural representation decoder. At each step, + predicts the i-th segment's offsets conditioned on the global latent z, + the arc position t_i, and the previous segment's predicted offsets T_{i-1}. + + T_i = F(z, t_i | T_{i-1}) + + The GRU hidden state carries long-range context from all previous + segments, while T_{i-1} provides an explicit local continuity signal. + """ + + def __init__( + self, + segment_len=10, + latent_dim=32, + d_hidden=256, + n_layers=4, + n_frequencies=16, + dropout=0.1, + ): + """ + Parameters + ---------- + segment_len : int + Number of points per segment. + latent_dim : int + Dimension of the input latent vector z. + d_hidden : int + GRU hidden dimension. + n_layers : int + Number of GRU layers. + n_frequencies : int + Number of Fourier frequency bands for positional encoding of t. + dropout : float + Dropout probability. + """ + super().__init__() + self.segment_len = segment_len + self.n_frequencies = n_frequencies + d_seg = segment_len * 3 + + # Seed the GRU hidden state from z + self.latent_proj = nn.Linear(latent_dim, d_hidden) + + # Input at each step: [T_{i-1} | pe(t_i) | z] + # Including z directly at every step (alongside the hidden state) lets + # the model re-attend to the global code at each position, rather than + # relying purely on it surviving through the hidden state + d_input = d_seg + 2 * n_frequencies + latent_dim + self.gru = nn.GRU( + input_size=d_input, + hidden_size=d_hidden, + num_layers=n_layers, batch_first=True, + dropout=dropout if n_layers > 1 else 0, ) - self.transformer = nn.TransformerDecoder( - decoder_layer, num_layers=n_layers + self.output_proj = nn.Sequential( + nn.LayerNorm(d_hidden), + nn.Linear(d_hidden, d_seg), ) - self.to_points = nn.Sequential( - nn.LayerNorm(d_token), - nn.Linear(d_token, segment_len * 3), - ) + def positional_encoding(self, t): + """ + Fourier positional encoding for scalar arc positions. + + Parameters + ---------- + t : torch.Tensor + Arc positions of shape (n_segments,) in [0, 1]. + + Returns + ------- + torch.Tensor + Shape (n_segments, 2 * n_frequencies). + """ + freqs = 2 ** torch.arange(self.n_frequencies, device=t.device).float() + x = t.unsqueeze(-1) * freqs * torch.pi + return torch.cat([torch.sin(x), torch.cos(x)], dim=-1) - def forward(self, z, encoder_tokens, encoder_mask=None, n_segments=None): + def forward(self, z, n_segments=None): """ Parameters ---------- z : torch.Tensor Latent vector of shape (B, latent_dim). - encoder_tokens : torch.Tensor - Per-token encoder outputs of shape (B, n_segments + 2, d_token). - encoder_mask : torch.Tensor, optional - Shape (B, n_segments + 2), True where padding. Passed as - memory_key_padding_mask to cross-attention. Default is None. n_segments : int, optional - Number of output segments. Inferred from encoder tokens if not - provided. + Number of output segments to decode. Returns ------- - curve : torch.Tensor - Reconstructed curve of shape (B, n_segments * segment_len, 3). + torch.Tensor + Reconstructed offset sequence of shape (B, n_segments * segment_len, 3). """ B = z.shape[0] - d_token = encoder_tokens.shape[2] - n_segments = n_segments or self.n_segments + n_segments = n_segments or 10 + d_seg = self.segment_len * 3 + + # Seed hidden state from z: (n_layers, B, d_hidden) + h = ( + self.latent_proj(z) + .unsqueeze(0) + .expand(self.gru.num_layers, -1, -1) + .contiguous() + ) - # Sinusoidal queries over arc position, biased by global latent - pe = sinusoidal_encoding(n_segments, d_token, encoder_tokens.device) - latent = self.latent_proj(z).unsqueeze(1) # (B, 1, d_token) - queries = pe.expand(B, -1, -1) + latent # (B, n_seg, d_token) + # Arc positions and Fourier encodings: (n_segments, 2 * n_freq) + t = torch.linspace(0, 1, n_segments, device=z.device) + pe = self.positional_encoding(t) - out = self.transformer( - queries, - encoder_tokens, - memory_key_padding_mask=encoder_mask, - ) # (B, n_seg, d_token) + # Autoregressive loop + outputs = [] + T_prev = torch.zeros(B, d_seg, device=z.device) # T_{-1}: start token + for s in range(n_segments): + pe_s = pe[s].unsqueeze(0).expand(B, -1) # (B, 2*n_freq) - segments = self.to_points(out) # (B, n_seg, seg_len*3) - offsets = segments.reshape(B, n_segments * self.segment_len, 3) - return offsets + # Concatenate T_{i-1}, pe(t_i), and z at every step + x_s = torch.cat([T_prev, pe_s, z], dim=-1).unsqueeze( + 1 + ) # (B, 1, d_input) + out, h = self.gru(x_s, h) # (B, 1, d_hidden) + T_i = self.output_proj(out.squeeze(1)) # (B, d_seg) + + outputs.append(T_i) + T_prev = T_i + + out = torch.stack(outputs, dim=1) # (B, n_segments, d_seg) + return out.reshape(B, n_segments * self.segment_len, 3) class CurveAutoencoder(nn.Module): def __init__( self, + decoder, + decoder_name=None, n_points=100, segment_len=10, d_token=64, @@ -270,6 +450,7 @@ def __init__( # Config self.config = { + "decoder_name": decoder_name, "n_points": n_points, "segment_len": segment_len, "d_token": d_token, @@ -290,16 +471,7 @@ def __init__( latent_dim=latent_dim, dropout=dropout, ) - self.decoder = CurveDecoder( - n_points=n_points, - segment_len=segment_len, - d_token=d_token, - n_heads=n_heads, - n_layers=n_layers, - d_ff=d_ff, - latent_dim=latent_dim, - dropout=dropout, - ) + self.decoder = decoder def forward(self, offsets, token_mask): """ @@ -317,7 +489,7 @@ def forward(self, offsets, token_mask): """ z, encoder_tokens = self.encoder(offsets, token_mask) n_segments = offsets.shape[1] // self.decoder.segment_len - reconstruction = self.decoder(z, encoder_tokens, n_segments=n_segments) + reconstruction = self.decoder(z, n_segments=n_segments) return reconstruction, z def encode(self, offsets): @@ -337,6 +509,50 @@ def load(cls, path): # --- Helpers --- +class ResidualBlock(nn.Module): + def __init__(self, d_hidden, dropout=0.1): + super().__init__() + self.block = nn.Sequential( + nn.Linear(d_hidden, d_hidden), + nn.LayerNorm(d_hidden), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(d_hidden, d_hidden), + nn.LayerNorm(d_hidden), + ) + self.act = nn.GELU() + + def forward(self, x): + return self.act(x + self.block(x)) + + +class SirenLayer(nn.Module): + def __init__(self, d_in, d_out, omega=30.0, is_first=False): + """ + Parameters + ---------- + omega : float + Frequency scaling factor. Default 30 following the original paper. + is_first : bool + First layer uses a different initialization scheme. + """ + super().__init__() + self.omega = omega + self.linear = nn.Linear(d_in, d_out) + self._init_weights(is_first, d_in) + + def _init_weights(self, is_first, d_in): + with torch.no_grad(): + if is_first: + bound = 1 / d_in + else: + bound = np.sqrt(6 / d_in) / self.omega + self.linear.weight.uniform_(-bound, bound) + + def forward(self, x): + return torch.sin(self.omega * self.linear(x)) + + def sinusoidal_encoding(n_tokens, d_token, device): """ Sinusoidal positional encoding over normalised arc position [0, 1]. diff --git a/src/neuron_proofreader/machine_learning/image_dataloader.py b/src/neuron_proofreader/machine_learning/image_dataloader.py index b1fd6a5c..1a965c09 100644 --- a/src/neuron_proofreader/machine_learning/image_dataloader.py +++ b/src/neuron_proofreader/machine_learning/image_dataloader.py @@ -19,11 +19,11 @@ ) from neuron_proofreader.utils import geometry_util, img_util, util - # ---------------------------------------------------------------------------- # Image Class # ---------------------------------------------------------------------------- + class TensorStoreImage: """ Class that reads images with the TensorStore library. @@ -100,6 +100,7 @@ def shape(self): # PatchLoader Class # ---------------------------------------------------------------------------- + class PatchLoader(ABC): """ A class for reading image patches and generating segment masks. @@ -176,7 +177,7 @@ def create_mask(self, center, shape, node): # Annotate mask mask = np.zeros(shape, dtype=np.float32) - #self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP + # self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP self.annotate_fragment(mask, subgraph, offset, fill_val=1) return mask @@ -231,6 +232,7 @@ def stack(img, mask): # PatchLoader Subclasses # ---------------------------------------------------------------------------- + class DetectionPatchLoader(PatchLoader): # --- Implementation of Abstract Inferface --- diff --git a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py index c85ae38c..d7c7bb8f 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py @@ -442,7 +442,7 @@ def __iter__(self): # Split into batches upfront batch_idx_groups = [ - idxs[start : min(start + self.batch_size, len(idxs))] + idxs[start: min(start + self.batch_size, len(idxs))] for start in range(0, len(idxs), self.batch_size) ] diff --git a/src/neuron_proofreader/merge_proofreading/merge_detection.py b/src/neuron_proofreader/merge_proofreading/merge_detection.py index c0736a06..81e4fb0a 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_detection.py +++ b/src/neuron_proofreader/merge_proofreading/merge_detection.py @@ -176,7 +176,9 @@ def save_results( ): self.save_sites(output_dir) if save_fragments: - self.dataset.graph.node_radius = 10 * np.maximum(self.node_preds, 0.1) + self.dataset.graph.node_radius = 10 * np.maximum( + self.node_preds, 0.1 + ) fragments_path = os.path.join(output_dir, "fragments.zip") self.dataset.to_zipped_swcs(fragments_path, use_radius=True) diff --git a/src/neuron_proofreader/merge_proofreading/search_datasets.py b/src/neuron_proofreader/merge_proofreading/search_datasets.py index 52833cbf..ad761584 100644 --- a/src/neuron_proofreader/merge_proofreading/search_datasets.py +++ b/src/neuron_proofreader/merge_proofreading/search_datasets.py @@ -15,14 +15,10 @@ from threading import Thread from torch.utils.data import IterableDataset -import itertools import networkx as nx import numpy as np import torch -from neuron_proofreader.machine_learning.point_cloud_models import ( - subgraph_to_point_cloud, -) from neuron_proofreader.machine_learning.image_dataloader import ( DetectionBatchLoader, DetectionPatchLoader, @@ -64,31 +60,33 @@ def __init__( def __iter__(self): patch_queue = Queue(maxsize=self.prefetch) sentinel = object() - + def producer(): sites = self._all_sites() with ThreadPoolExecutor(max_workers=self.prefetch) as executor: futures = {} - + def fill(): while len(futures) < self.prefetch: try: site = next(sites) - futures[executor.submit(self.patch_loader, site)] = site + futures[ + executor.submit(self.patch_loader, site) + ] = site except StopIteration: break - + fill() while futures: done, _ = wait(futures, return_when=FIRST_COMPLETED) for f in done: patch_queue.put((futures.pop(f), f.result())) fill() - + patch_queue.put(sentinel) - + Thread(target=producer, daemon=True).start() - + while True: item = patch_queue.get() if item is sentinel: diff --git a/src/neuron_proofreader/proofreading_pipeline.py b/src/neuron_proofreader/proofreading_pipeline.py index 0cb8f8ab..49b1bec3 100644 --- a/src/neuron_proofreader/proofreading_pipeline.py +++ b/src/neuron_proofreader/proofreading_pipeline.py @@ -146,7 +146,9 @@ def split_proofreading( def connect_soma_fragments(self, max_dist=25): self.step_cnt += 1 - self.log(f"\nStep {self.step_cnt}: Connect Soma Fragments with dist={max_dist}") + self.log( + f"\nStep {self.step_cnt}: Connect Soma Fragments with dist={max_dist}" + ) summary = self.graph.connect_soma_fragments(max_dist=max_dist) self.log(summary) @@ -173,7 +175,7 @@ def merge_proofreading(self, mode): swc_util.write_points( zip_path, merge_sites, color=color, prefix="merge_site", radius=10 ) - + # --- Helpers --- def log(self, txt): """ diff --git a/src/neuron_proofreader/skeleton_graph.py b/src/neuron_proofreader/skeleton_graph.py index e26a3f7c..ed649949 100644 --- a/src/neuron_proofreader/skeleton_graph.py +++ b/src/neuron_proofreader/skeleton_graph.py @@ -535,6 +535,7 @@ def component_to_zipped_swc(self, zipfile, root, use_radius=False): Indication of whether to preserve radii of nodes or use default radius of 2μm. Default is False. """ + # Subroutines def write_entry(node, parent): """ diff --git a/src/neuron_proofreader/utils/swc_util.py b/src/neuron_proofreader/utils/swc_util.py index 920af307..fc87b644 100644 --- a/src/neuron_proofreader/utils/swc_util.py +++ b/src/neuron_proofreader/utils/swc_util.py @@ -202,7 +202,9 @@ def read_zips(self, zip_paths, read_fn): pbar = self.manual_progress_bar(len(zip_paths)) with ProcessPoolExecutor() as executor: # Assign processes - futures = {executor.submit(read_fn, path) for path in zip_paths[0:500]} # TEMP + futures = { + executor.submit(read_fn, path) for path in zip_paths[0:500] + } # TEMP # Store results swc_dicts = deque() diff --git a/src/neuron_proofreader/utils/util.py b/src/neuron_proofreader/utils/util.py index df623297..4056cbfb 100644 --- a/src/neuron_proofreader/utils/util.py +++ b/src/neuron_proofreader/utils/util.py @@ -401,7 +401,7 @@ def parse_cloud_path(path): """ # Remove s3:// or gs:// if present if path.startswith("s3://") or path.startswith("gs://"): - path = path[len("s3://") :] + path = path[len("s3://"):] # Split path parts = path.split("/", 1) From 5c2781ce8108589e2f40bd085a8c2acc9dcb858d Mon Sep 17 00:00:00 2001 From: anna-grim Date: Mon, 22 Jun 2026 22:11:39 +0000 Subject: [PATCH 2/3] feat: curve eval --- .../geometric_learning/curve_datamodules.py | 5 +- .../geometric_learning/curve_transformer.py | 44 +++---- .../geometric_learning/curve_visualization.py | 113 ++++++++++++++++++ 3 files changed, 139 insertions(+), 23 deletions(-) create mode 100644 src/neuron_proofreader/geometric_learning/curve_visualization.py diff --git a/src/neuron_proofreader/geometric_learning/curve_datamodules.py b/src/neuron_proofreader/geometric_learning/curve_datamodules.py index e957f447..5492e3ba 100644 --- a/src/neuron_proofreader/geometric_learning/curve_datamodules.py +++ b/src/neuron_proofreader/geometric_learning/curve_datamodules.py @@ -28,11 +28,13 @@ def __init__( swcs_path, graph_config=None, max_length=np.inf, + segment_len=10, transform=None, ): # Instance attributes self.brain_id = brain_id self.max_length = max_length + self.segment_len = segment_len self.transform = transform # Core data structures @@ -53,7 +55,8 @@ def load_skeletons(self, config, swcs_path): def get_valid_paths(self): paths = list() for p in self.irreducible_paths(): - if self.path_length(p) < self.max_length: + length = self.path_length(p) + if length < self.max_length and len(p) > self.segment_len: paths.append(p) return paths diff --git a/src/neuron_proofreader/geometric_learning/curve_transformer.py b/src/neuron_proofreader/geometric_learning/curve_transformer.py index 533558dc..356b0e67 100644 --- a/src/neuron_proofreader/geometric_learning/curve_transformer.py +++ b/src/neuron_proofreader/geometric_learning/curve_transformer.py @@ -271,13 +271,9 @@ def forward(self, z, n_segments=None): if self.mlp_type == "siren": # Raw scalar t concatenated with z — SIREN encodes frequency internally - t_exp = ( - t.unsqueeze(-1).unsqueeze(0).expand(B, -1, -1) - ) # (B, n_seg, 1) - z_exp = z.unsqueeze(1).expand( - -1, n_segments, -1 - ) # (B, n_seg, latent_dim) - x = torch.cat([z_exp, t_exp], dim=-1) # (B, n_seg, latent_dim+1) + t_exp = (t.unsqueeze(-1).unsqueeze(0).expand(B, -1, -1)) + z_exp = z.unsqueeze(1).expand(-1, n_segments, -1) + x = torch.cat([z_exp, t_exp], dim=-1) out = self.mlp(x) elif self.mlp_type == "residual": @@ -414,19 +410,19 @@ def forward(self, z, n_segments=None): outputs = [] T_prev = torch.zeros(B, d_seg, device=z.device) # T_{-1}: start token for s in range(n_segments): - pe_s = pe[s].unsqueeze(0).expand(B, -1) # (B, 2*n_freq) + # Create GRU input + pe_s = pe[s].unsqueeze(0).expand(B, -1) + x_s = torch.cat([T_prev, pe_s, z], dim=-1).unsqueeze(1) - # Concatenate T_{i-1}, pe(t_i), and z at every step - x_s = torch.cat([T_prev, pe_s, z], dim=-1).unsqueeze( - 1 - ) # (B, 1, d_input) - out, h = self.gru(x_s, h) # (B, 1, d_hidden) - T_i = self.output_proj(out.squeeze(1)) # (B, d_seg) + # Predict next token and update hidden state + out, h = self.gru(x_s, h) + T_i = self.output_proj(out.squeeze(1)) + # Update previous token outputs.append(T_i) T_prev = T_i - out = torch.stack(outputs, dim=1) # (B, n_segments, d_seg) + out = torch.stack(outputs, dim=1) return out.reshape(B, n_segments * self.segment_len, 3) @@ -473,12 +469,12 @@ def __init__( ) self.decoder = decoder - def forward(self, offsets, token_mask): + def forward(self, diffs, mask=None): """ Parameters ---------- - offsets : torch.Tensor - Shape (B, N, 3), normalized to the unit sphere, offsets[:, 0] == 0. + diffs : torch.Tensor + Shape (B, N, 3), normalized to the unit sphere, diffs[:, 0] == 0. Returns ------- @@ -487,13 +483,17 @@ def forward(self, offsets, token_mask): z : torch.Tensor Latent vector of shape (B, latent_dim). """ - z, encoder_tokens = self.encoder(offsets, token_mask) - n_segments = offsets.shape[1] // self.decoder.segment_len + z, encoder_tokens = self.encoder(diffs, mask) + if mask is not None: + valid_lengths = (~mask).sum(dim=1) # (B,) + n_segments = (valid_lengths.min() // self.decoder.segment_len).item() + else: + n_segments = diffs.shape[1] // self.decoder.segment_len reconstruction = self.decoder(z, n_segments=n_segments) return reconstruction, z - def encode(self, offsets): - z, _ = self.encoder(offsets) + def encode(self, diffs): + z, _ = self.encoder(diffs) return z # --- Helpers --- diff --git a/src/neuron_proofreader/geometric_learning/curve_visualization.py b/src/neuron_proofreader/geometric_learning/curve_visualization.py new file mode 100644 index 00000000..55ecaf55 --- /dev/null +++ b/src/neuron_proofreader/geometric_learning/curve_visualization.py @@ -0,0 +1,113 @@ +""" +Created on Mon June 8 17:00:00 2026 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +Code for visualizing 3D space curves and their embeddings. + +""" + +import matplotlib.pyplot as plt +import numpy as np +import plotly.graph_objects as go + + +# --- Plot Curves --- +def plot_curves(curve1, curve2, name1=None, name2=None): + fig = go.Figure( + data=[ + go.Scatter3d( + x=curve1[:, 0], y=curve1[:, 1], z=curve1[:, 2], + mode='lines', + name=name1, + line=dict(width=5, color='blue') + ), + go.Scatter3d( + x=curve2[:, 0], y=curve2[:, 1], z=curve2[:, 2], + mode='lines', + name=name2, + line=dict(width=5, color='green') + ), + go.Scatter3d( + x=[0], y=[0], z=[0], + mode='markers', + name="Origin", + marker=dict(size=3, color='red') + ) + ] + ) + fig.update_layout( + scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z') + ) + fig.show() + + +def plot_length_distribution(dataset_collection, title=None, output_path=None): + # Compute path length stats + lengths = dataset_collection.examples_df["length"] + p50 = round(np.percentile(lengths, 50), 2) + p99 = round(np.percentile(lengths, 99.9), 2) + thr_lengths = [l for l in lengths if l < p99] + + # Plot path lengths + plt.figure(figsize=(8, 5)) + plt.hist(thr_lengths, bins=50, edgecolor='white', linewidth=0.5, zorder=2) + add_line(p50, color="r", label=f"50th perc = {p50}") + add_line(p99, color="g", label=f"99.9th perc = {p99}") + + # Plot labels + plt.grid(axis='y', color='lightgrey', linewidth=0.5) + plt.legend(loc='upper right') + plt.title(title, fontsize=13) + plt.xlabel('Path Length (μm)', fontsize=12) + plt.ylabel('Count', fontsize=12) + plt.yscale("log") + + plt.figtext( + 0.01, -0.02, + 'Note: Path lengths thresholded at 99.9th percentile in this plot', + fontsize=9, + color='tab:grey', + va='bottom' + ) + plt.subplots_adjust(bottom=0.8) + visualize_result(output_path=output_path) + + +# --- Plot Curve Embeddings --- +def plot_latents_by_direction(latents, output_path=None): + pass + + +def plot_latents_by_length(latents, output_path=None): + # Compute PCA of latents + pca = PCA(n_components=2) + latents_2d = pca.fit_transform(latents) + path_lengths = dataset_collection.examples_df["length"] + + # Visualize latents and color by path length + plt.figure(figsize=(8, 6)) + sc = plt.scatter( + latents_2d[:, 0], latents_2d[:, 1], + c=path_lengths, cmap="viridis", s=10, alpha=0.7, + norm=LogNorm(vmin=path_lengths.min(), vmax=path_lengths.max()) + ) + plt.colorbar(sc, label="Path length (μm)") + plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)') + plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)') + plt.title("PCA of curve embeddings") + visualize_result(output_path=output_path) + + +# --- Helpers --- +def add_line(p, color=None, label=None): + plt.axvline(p, color=color, linestyle="--", label=label, zorder=2) + + +def visualize_result(output_path=None): + plt.tight_layout() + if output_path: + plt.savefig(output_path, dpi=300, bbox_inches="tight") + else: + plt.show() From ea0524bc0d68ae63f70c2275ed00db720858d948 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Tue, 23 Jun 2026 02:43:07 +0000 Subject: [PATCH 3/3] feat: eval plotting moved in --- .../geometric_learning/curve_transformer.py | 6 +- .../geometric_learning/curve_visualization.py | 132 +++++++++++++----- .../machine_learning/vision_models.py | 31 ---- .../merge_proofreading/merge_datamodules.py | 2 +- .../merge_proofreading/search_datasets.py | 2 +- src/neuron_proofreader/utils/geometry_util.py | 55 ++++++++ 6 files changed, 160 insertions(+), 68 deletions(-) diff --git a/src/neuron_proofreader/geometric_learning/curve_transformer.py b/src/neuron_proofreader/geometric_learning/curve_transformer.py index 356b0e67..80abb704 100644 --- a/src/neuron_proofreader/geometric_learning/curve_transformer.py +++ b/src/neuron_proofreader/geometric_learning/curve_transformer.py @@ -271,7 +271,7 @@ def forward(self, z, n_segments=None): if self.mlp_type == "siren": # Raw scalar t concatenated with z — SIREN encodes frequency internally - t_exp = (t.unsqueeze(-1).unsqueeze(0).expand(B, -1, -1)) + t_exp = t.unsqueeze(-1).unsqueeze(0).expand(B, -1, -1) z_exp = z.unsqueeze(1).expand(-1, n_segments, -1) x = torch.cat([z_exp, t_exp], dim=-1) out = self.mlp(x) @@ -486,7 +486,9 @@ def forward(self, diffs, mask=None): z, encoder_tokens = self.encoder(diffs, mask) if mask is not None: valid_lengths = (~mask).sum(dim=1) # (B,) - n_segments = (valid_lengths.min() // self.decoder.segment_len).item() + n_segments = ( + valid_lengths.min() // self.decoder.segment_len + ).item() else: n_segments = diffs.shape[1] // self.decoder.segment_len reconstruction = self.decoder(z, n_segments=n_segments) diff --git a/src/neuron_proofreader/geometric_learning/curve_visualization.py b/src/neuron_proofreader/geometric_learning/curve_visualization.py index 55ecaf55..1a4b16a7 100644 --- a/src/neuron_proofreader/geometric_learning/curve_visualization.py +++ b/src/neuron_proofreader/geometric_learning/curve_visualization.py @@ -8,37 +8,49 @@ """ +from colorsys import hsv_to_rgb +from matplotlib.colors import LogNorm +from sklearn.decomposition import PCA + import matplotlib.pyplot as plt import numpy as np import plotly.graph_objects as go +from neuron_proofreader.utils import geometry_util + # --- Plot Curves --- def plot_curves(curve1, curve2, name1=None, name2=None): fig = go.Figure( data=[ go.Scatter3d( - x=curve1[:, 0], y=curve1[:, 1], z=curve1[:, 2], - mode='lines', + x=curve1[:, 0], + y=curve1[:, 1], + z=curve1[:, 2], + mode="lines", name=name1, - line=dict(width=5, color='blue') + line=dict(width=5, color="blue"), ), go.Scatter3d( - x=curve2[:, 0], y=curve2[:, 1], z=curve2[:, 2], - mode='lines', + x=curve2[:, 0], + y=curve2[:, 1], + z=curve2[:, 2], + mode="lines", name=name2, - line=dict(width=5, color='green') + line=dict(width=5, color="green"), ), go.Scatter3d( - x=[0], y=[0], z=[0], - mode='markers', + x=[0], + y=[0], + z=[0], + mode="markers", name="Origin", - marker=dict(size=3, color='red') - ) + marker=dict(size=3, color="red"), + ), ] ) fig.update_layout( - scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z') + scene=dict(xaxis_title="x", yaxis_title="y", zaxis_title="z") ) fig.show() @@ -49,62 +61,116 @@ def plot_length_distribution(dataset_collection, title=None, output_path=None): p50 = round(np.percentile(lengths, 50), 2) p99 = round(np.percentile(lengths, 99.9), 2) thr_lengths = [l for l in lengths if l < p99] - + # Plot path lengths plt.figure(figsize=(8, 5)) - plt.hist(thr_lengths, bins=50, edgecolor='white', linewidth=0.5, zorder=2) + plt.hist(thr_lengths, bins=50, edgecolor="white", linewidth=0.5, zorder=2) add_line(p50, color="r", label=f"50th perc = {p50}") add_line(p99, color="g", label=f"99.9th perc = {p99}") # Plot labels - plt.grid(axis='y', color='lightgrey', linewidth=0.5) - plt.legend(loc='upper right') + plt.grid(axis="y", color="lightgrey", linewidth=0.5) + plt.legend(loc="upper right") plt.title(title, fontsize=13) - plt.xlabel('Path Length (μm)', fontsize=12) - plt.ylabel('Count', fontsize=12) + plt.xlabel("Path Length (μm)", fontsize=12) + plt.ylabel("Count", fontsize=12) plt.yscale("log") plt.figtext( - 0.01, -0.02, - 'Note: Path lengths thresholded at 99.9th percentile in this plot', + 0.01, + -0.02, + "Note: Path lengths thresholded at 99.9th percentile in this plot", fontsize=9, - color='tab:grey', - va='bottom' + color="tab:grey", + va="bottom", ) plt.subplots_adjust(bottom=0.8) visualize_result(output_path=output_path) # --- Plot Curve Embeddings --- -def plot_latents_by_direction(latents, output_path=None): - pass +def plot_error_vs_length(lengths, rmse_results, output_path=None): + # Set colors + norm = LogNorm(vmin=lengths.min(), vmax=lengths.max()) + colors = plt.cm.viridis(norm(lengths)) + + # Plot + plt.figure(figsize=(8, 6)) + plt.scatter(lengths, rmse_results, c=colors, s=10, alpha=0.8) + plt.xscale("log") + plt.yscale("log") + plt.xlabel("Path Length (μm)", fontsize=12) + plt.ylabel("RMSE", fontsize=12) + visualize_result(output_path=output_path) -def plot_latents_by_length(latents, output_path=None): - # Compute PCA of latents +def plot_latents_by_pca(curves, latents, output_path=None): + # PCA of latents pca = PCA(n_components=2) latents_2d = pca.fit_transform(latents) - path_lengths = dataset_collection.examples_df["length"] + lengths = np.array([geometry_util.compute_length(c) for c in curves]) + + # Visualize results + _plot_latents_by_direction(curves, latents_2d, pca, output_path) + _plot_latents_by_length(lengths, latents_2d, pca, output_path) + - # Visualize latents and color by path length +def _plot_latents_by_direction(curves, latents_2d, pca, output_path=None): + # Compute directions and colors for each curve + directions = np.array([curve_principal_direction(c) for c in curves]) + colors = np.array([direction_to_color(d) for d in directions]) + + # Plot + plt.figure(figsize=(8, 6)) + plt.scatter(latents_2d[:, 0], latents_2d[:, 1], c=colors, s=10, alpha=0.8) + plt.xlabel(f"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)") + plt.ylabel(f"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)") + plt.title("Curve Embeddings Colored by Principal Direction") + visualize_result(output_path=output_path) + + +def _plot_latents_by_length(lengths, latents_2d, pca, output_path=None): plt.figure(figsize=(8, 6)) sc = plt.scatter( - latents_2d[:, 0], latents_2d[:, 1], - c=path_lengths, cmap="viridis", s=10, alpha=0.7, - norm=LogNorm(vmin=path_lengths.min(), vmax=path_lengths.max()) + latents_2d[:, 0], + latents_2d[:, 1], + c=lengths, + cmap="viridis", + s=10, + alpha=0.7, + norm=LogNorm(vmin=lengths.min(), vmax=lengths.max()), ) plt.colorbar(sc, label="Path length (μm)") - plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)') - plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)') + plt.xlabel(f"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)") + plt.ylabel(f"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)") plt.title("PCA of curve embeddings") visualize_result(output_path=output_path) - + # --- Helpers --- def add_line(p, color=None, label=None): plt.axvline(p, color=color, linestyle="--", label=label, zorder=2) +def curve_principal_direction(curve): + curve_pca = PCA(n_components=1) + curve_pca.fit(curve) + direction = curve_pca.components_[0] + if direction[2] < 0: + direction = -direction + return direction / np.linalg.norm(direction) + + +def direction_to_color(direction): + x, y, z = direction + azimuth = np.arctan2(y, x) + hue = (azimuth + np.pi) / (2 * np.pi) + polar = np.arccos(np.clip(z, 0, 1)) + saturation = polar / (np.pi / 2) + value = 1.0 + return hsv_to_rgb(hue, saturation, value) + + def visualize_result(output_path=None): plt.tight_layout() if output_path: diff --git a/src/neuron_proofreader/machine_learning/vision_models.py b/src/neuron_proofreader/machine_learning/vision_models.py index 21b6760d..d514df69 100644 --- a/src/neuron_proofreader/machine_learning/vision_models.py +++ b/src/neuron_proofreader/machine_learning/vision_models.py @@ -9,7 +9,6 @@ """ -# from neurobase.finetune import finetune_model from einops import rearrange import torch @@ -131,36 +130,6 @@ def forward(self, x): # --- Transformers --- -class MAE3D(nn.Module): - - def __init__(self, checkpoint_path, model_config): - # Call parent class - super().__init__() - - # Load model - full_model = finetune_model( - checkpoint_path=checkpoint_path, - model_config=model_config, - task_head_config="binary_classifier", - freeze_encoder=True, - ) - - # Instance attributes - self.encoder = full_model.encoder - self.output = FeedForwardNet(384, 1, 3) - - def forward(self, x): - latent0 = self.encoder(x[:, 0:1, ...]) - latent1 = self.encoder(x[:, 1:2, ...]) - - x0 = latent0["latents"][:, 0, :] - x1 = latent1["latents"][:, 0, :] - - x = torch.cat((x0, x1), dim=1) - x = self.output(x) - return x - - class ViT3D(nn.Module): """ A class that implements a 3D Vision transformer. diff --git a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py index d7c7bb8f..c85ae38c 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py @@ -442,7 +442,7 @@ def __iter__(self): # Split into batches upfront batch_idx_groups = [ - idxs[start: min(start + self.batch_size, len(idxs))] + idxs[start : min(start + self.batch_size, len(idxs))] for start in range(0, len(idxs), self.batch_size) ] diff --git a/src/neuron_proofreader/merge_proofreading/search_datasets.py b/src/neuron_proofreader/merge_proofreading/search_datasets.py index ad761584..d2666dac 100644 --- a/src/neuron_proofreader/merge_proofreading/search_datasets.py +++ b/src/neuron_proofreader/merge_proofreading/search_datasets.py @@ -23,7 +23,7 @@ DetectionBatchLoader, DetectionPatchLoader, ) -from neuron_proofreader.utils import img_util, ml_util, util +from neuron_proofreader.utils import img_util, util # --- Datasets --- diff --git a/src/neuron_proofreader/utils/geometry_util.py b/src/neuron_proofreader/utils/geometry_util.py index c8b03acf..25baa13a 100644 --- a/src/neuron_proofreader/utils/geometry_util.py +++ b/src/neuron_proofreader/utils/geometry_util.py @@ -16,9 +16,28 @@ import networkx as nx import numpy as np +import torch # --- Curve Utils --- +def compute_length(curve): + """ + Computes the Euclidean length of the given curve. + + Parameters + ---------- + curve : numpy.ndarray + Array of points that form an n-d curve. + + Returns + ------- + float + Euclidean length of the given curve. + """ + diffs = curve[1:] - curve[:-1] + return np.linalg.norm(diffs**2, axis=1).sum() + + def fit_spline_1d(pts, k=3, s=None): """ Fits a spline to 1D curve. @@ -142,6 +161,42 @@ def resample_curve_3d(pts, n_pts=None, s=None): return pts +def reconstruct_diffs(diffs): + """ + Reconstructs a curve from a sequence of offset vectors. + + Parameters + ---------- + diffs : numpy.ndarray or torch.Tensor + Array representing the differences between consecutive points. + + Returns + ------- + numpy.ndarray + Reconstructed curve. + """ + if isinstance(diffs, torch.Tensor): + start = torch.zeros(1, 3, device=diffs.device, dtype=diffs.dtype) + return torch.cat([start, start + torch.cumsum(diffs, dim=0)], dim=0) + else: + start = np.zeros((1, 3)) + return np.concatenate( + [start, start + np.cumsum(diffs, axis=0)], axis=0 + ) + + +def rmse(curve1, curve2): + """ + Computes Root Mean Squared Error (RMSE) between two curves. + + Parameters + ---------- + curve1 : numpy.ndarray + + """ + return np.sqrt(np.mean(np.sum((curve1 - curve2) ** 2, axis=1))) + + # --- Fragment Filtering --- def remove_doubles(graph, max_cable_length): """