diff --git a/src/neuron_proofreader/geometric_learning/curve_transformer.py b/src/neuron_proofreader/geometric_learning/curve_transformer.py index 45cc0daf..533558dc 100644 --- a/src/neuron_proofreader/geometric_learning/curve_transformer.py +++ b/src/neuron_proofreader/geometric_learning/curve_transformer.py @@ -149,113 +149,293 @@ def forward(self, offsets, mask=None): return z, tokens -class CurveDecoder(nn.Module): +class INRDecoder(nn.Module): """ - Transformer decoder that reconstructs a 3D space curve from a latent - vector and the encoder's token representations. Positional queries are - sinusoidally encoded over arc position and biased by the global latent. - The output resolution can differ from the encoder input, allowing - decoding at arbitrary granularity. + Implicit neural representation decoder. Models the curve as F(z, t) -> + offsets where t in [0, 1] is the normalized arc position of each segment. + Supports three MLP variants: plain, residual (ResNet-style), and SIREN. """ def __init__( self, - n_points=100, segment_len=10, - d_token=64, - n_heads=4, - n_layers=4, - d_ff=64, latent_dim=32, + d_hidden=256, + n_layers=4, + n_frequencies=16, + mlp_type="residual", dropout=0.1, ): """ Parameters ---------- - n_points : int - Default number of output curve points. segment_len : int - Number of points per segment token (must match encoder). - d_token : int - Dimension of each token throughout the transformer. - n_heads : int - Number of attention heads. - n_layers : int - Number of transformer decoder layers. - d_ff : int - Feed-forward hidden dimension. + Number of points per segment. latent_dim : int Dimension of the input latent vector. + d_hidden : int + Hidden dimension of the MLP. + n_layers : int + Number of hidden layers. + n_frequencies : int + Number of Fourier frequency bands for positional encoding of t. + Not used when mlp_type is 'siren' (SIREN handles PE internally). + mlp_type : str + One of 'plain', 'residual', or 'siren'. dropout : float - Dropout probability. + Dropout probability. Not used in SIREN. """ super().__init__() self.segment_len = segment_len - self.n_segments = n_points // segment_len + self.n_frequencies = n_frequencies + self.mlp_type = mlp_type + + d_out = segment_len * 3 + + if mlp_type == "siren": + # SIREN handles position encoding internally via sin activations — + # just concatenate raw t (scalar) with z + d_input = latent_dim + 1 + layers = [SirenLayer(d_input, d_hidden, is_first=True)] + for _ in range(n_layers - 1): + layers.append(SirenLayer(d_hidden, d_hidden)) + layers.append(nn.Linear(d_hidden, d_out)) + self.mlp = nn.Sequential(*layers) + + elif mlp_type == "residual": + d_input = latent_dim + 2 * n_frequencies + self.input_proj = nn.Sequential( + nn.Linear(d_input, d_hidden), + nn.LayerNorm(d_hidden), + nn.GELU(), + ) + self.blocks = nn.ModuleList( + [ResidualBlock(d_hidden, dropout) for _ in range(n_layers)] + ) + self.output_proj = nn.Linear(d_hidden, d_out) + + elif mlp_type == "plain": + d_input = latent_dim + 2 * n_frequencies + layers = [ + nn.Linear(d_input, d_hidden), + nn.GELU(), + nn.Dropout(dropout), + ] + for _ in range(n_layers - 1): + layers += [ + nn.Linear(d_hidden, d_hidden), + nn.GELU(), + nn.Dropout(dropout), + ] + layers.append(nn.Linear(d_hidden, d_out)) + self.mlp = nn.Sequential(*layers) - # Project latent to d_token to bias the positional queries - self.latent_proj = nn.Linear(latent_dim, d_token) + else: + raise ValueError(f"Unknown mlp_type: {mlp_type}") - decoder_layer = nn.TransformerDecoderLayer( - d_model=d_token, - nhead=n_heads, - dim_feedforward=d_ff, - dropout=dropout, + def positional_encoding(self, t): + """ + Fourier positional encoding for scalar arc positions. + + Parameters + ---------- + t : torch.Tensor + Arc positions of shape (n_segments,) in [0, 1]. + + Returns + ------- + torch.Tensor + Shape (n_segments, 2 * n_frequencies). + """ + freqs = 2 ** torch.arange(self.n_frequencies, device=t.device).float() + x = t.unsqueeze(-1) * freqs * torch.pi + return torch.cat([torch.sin(x), torch.cos(x)], dim=-1) + + def forward(self, z, n_segments=None): + """ + Parameters + ---------- + z : torch.Tensor + Latent vector of shape (B, latent_dim). + n_segments : int, optional + Number of output segments. + + Returns + ------- + torch.Tensor + Reconstructed offset sequence of shape (B, n_segments * segment_len, 3). + """ + B = z.shape[0] + n_segments = n_segments or 10 + t = torch.linspace(0, 1, n_segments, device=z.device) + + if self.mlp_type == "siren": + # Raw scalar t concatenated with z — SIREN encodes frequency internally + t_exp = ( + t.unsqueeze(-1).unsqueeze(0).expand(B, -1, -1) + ) # (B, n_seg, 1) + z_exp = z.unsqueeze(1).expand( + -1, n_segments, -1 + ) # (B, n_seg, latent_dim) + x = torch.cat([z_exp, t_exp], dim=-1) # (B, n_seg, latent_dim+1) + out = self.mlp(x) + + elif self.mlp_type == "residual": + pe = self.positional_encoding(t) + z_exp = z.unsqueeze(1).expand(-1, n_segments, -1) + pe_exp = pe.unsqueeze(0).expand(B, -1, -1) + x = torch.cat([z_exp, pe_exp], dim=-1) + x = self.input_proj(x) + for block in self.blocks: + x = block(x) + out = self.output_proj(x) + + else: # plain + pe = self.positional_encoding(t) + z_exp = z.unsqueeze(1).expand(-1, n_segments, -1) + pe_exp = pe.unsqueeze(0).expand(B, -1, -1) + x = torch.cat([z_exp, pe_exp], dim=-1) + out = self.mlp(x) + + return out.reshape(B, n_segments * self.segment_len, 3) + + +class AutoregressiveINRDecoder(nn.Module): + """ + Autoregressive implicit neural representation decoder. At each step, + predicts the i-th segment's offsets conditioned on the global latent z, + the arc position t_i, and the previous segment's predicted offsets T_{i-1}. + + T_i = F(z, t_i | T_{i-1}) + + The GRU hidden state carries long-range context from all previous + segments, while T_{i-1} provides an explicit local continuity signal. + """ + + def __init__( + self, + segment_len=10, + latent_dim=32, + d_hidden=256, + n_layers=4, + n_frequencies=16, + dropout=0.1, + ): + """ + Parameters + ---------- + segment_len : int + Number of points per segment. + latent_dim : int + Dimension of the input latent vector z. + d_hidden : int + GRU hidden dimension. + n_layers : int + Number of GRU layers. + n_frequencies : int + Number of Fourier frequency bands for positional encoding of t. + dropout : float + Dropout probability. + """ + super().__init__() + self.segment_len = segment_len + self.n_frequencies = n_frequencies + d_seg = segment_len * 3 + + # Seed the GRU hidden state from z + self.latent_proj = nn.Linear(latent_dim, d_hidden) + + # Input at each step: [T_{i-1} | pe(t_i) | z] + # Including z directly at every step (alongside the hidden state) lets + # the model re-attend to the global code at each position, rather than + # relying purely on it surviving through the hidden state + d_input = d_seg + 2 * n_frequencies + latent_dim + self.gru = nn.GRU( + input_size=d_input, + hidden_size=d_hidden, + num_layers=n_layers, batch_first=True, + dropout=dropout if n_layers > 1 else 0, ) - self.transformer = nn.TransformerDecoder( - decoder_layer, num_layers=n_layers + self.output_proj = nn.Sequential( + nn.LayerNorm(d_hidden), + nn.Linear(d_hidden, d_seg), ) - self.to_points = nn.Sequential( - nn.LayerNorm(d_token), - nn.Linear(d_token, segment_len * 3), - ) + def positional_encoding(self, t): + """ + Fourier positional encoding for scalar arc positions. + + Parameters + ---------- + t : torch.Tensor + Arc positions of shape (n_segments,) in [0, 1]. + + Returns + ------- + torch.Tensor + Shape (n_segments, 2 * n_frequencies). + """ + freqs = 2 ** torch.arange(self.n_frequencies, device=t.device).float() + x = t.unsqueeze(-1) * freqs * torch.pi + return torch.cat([torch.sin(x), torch.cos(x)], dim=-1) - def forward(self, z, encoder_tokens, encoder_mask=None, n_segments=None): + def forward(self, z, n_segments=None): """ Parameters ---------- z : torch.Tensor Latent vector of shape (B, latent_dim). - encoder_tokens : torch.Tensor - Per-token encoder outputs of shape (B, n_segments + 2, d_token). - encoder_mask : torch.Tensor, optional - Shape (B, n_segments + 2), True where padding. Passed as - memory_key_padding_mask to cross-attention. Default is None. n_segments : int, optional - Number of output segments. Inferred from encoder tokens if not - provided. + Number of output segments to decode. Returns ------- - curve : torch.Tensor - Reconstructed curve of shape (B, n_segments * segment_len, 3). + torch.Tensor + Reconstructed offset sequence of shape (B, n_segments * segment_len, 3). """ B = z.shape[0] - d_token = encoder_tokens.shape[2] - n_segments = n_segments or self.n_segments + n_segments = n_segments or 10 + d_seg = self.segment_len * 3 + + # Seed hidden state from z: (n_layers, B, d_hidden) + h = ( + self.latent_proj(z) + .unsqueeze(0) + .expand(self.gru.num_layers, -1, -1) + .contiguous() + ) - # Sinusoidal queries over arc position, biased by global latent - pe = sinusoidal_encoding(n_segments, d_token, encoder_tokens.device) - latent = self.latent_proj(z).unsqueeze(1) # (B, 1, d_token) - queries = pe.expand(B, -1, -1) + latent # (B, n_seg, d_token) + # Arc positions and Fourier encodings: (n_segments, 2 * n_freq) + t = torch.linspace(0, 1, n_segments, device=z.device) + pe = self.positional_encoding(t) - out = self.transformer( - queries, - encoder_tokens, - memory_key_padding_mask=encoder_mask, - ) # (B, n_seg, d_token) + # Autoregressive loop + outputs = [] + T_prev = torch.zeros(B, d_seg, device=z.device) # T_{-1}: start token + for s in range(n_segments): + pe_s = pe[s].unsqueeze(0).expand(B, -1) # (B, 2*n_freq) - segments = self.to_points(out) # (B, n_seg, seg_len*3) - offsets = segments.reshape(B, n_segments * self.segment_len, 3) - return offsets + # Concatenate T_{i-1}, pe(t_i), and z at every step + x_s = torch.cat([T_prev, pe_s, z], dim=-1).unsqueeze( + 1 + ) # (B, 1, d_input) + out, h = self.gru(x_s, h) # (B, 1, d_hidden) + T_i = self.output_proj(out.squeeze(1)) # (B, d_seg) + + outputs.append(T_i) + T_prev = T_i + + out = torch.stack(outputs, dim=1) # (B, n_segments, d_seg) + return out.reshape(B, n_segments * self.segment_len, 3) class CurveAutoencoder(nn.Module): def __init__( self, + decoder, + decoder_name=None, n_points=100, segment_len=10, d_token=64, @@ -270,6 +450,7 @@ def __init__( # Config self.config = { + "decoder_name": decoder_name, "n_points": n_points, "segment_len": segment_len, "d_token": d_token, @@ -290,16 +471,7 @@ def __init__( latent_dim=latent_dim, dropout=dropout, ) - self.decoder = CurveDecoder( - n_points=n_points, - segment_len=segment_len, - d_token=d_token, - n_heads=n_heads, - n_layers=n_layers, - d_ff=d_ff, - latent_dim=latent_dim, - dropout=dropout, - ) + self.decoder = decoder def forward(self, offsets, token_mask): """ @@ -317,7 +489,7 @@ def forward(self, offsets, token_mask): """ z, encoder_tokens = self.encoder(offsets, token_mask) n_segments = offsets.shape[1] // self.decoder.segment_len - reconstruction = self.decoder(z, encoder_tokens, n_segments=n_segments) + reconstruction = self.decoder(z, n_segments=n_segments) return reconstruction, z def encode(self, offsets): @@ -337,6 +509,50 @@ def load(cls, path): # --- Helpers --- +class ResidualBlock(nn.Module): + def __init__(self, d_hidden, dropout=0.1): + super().__init__() + self.block = nn.Sequential( + nn.Linear(d_hidden, d_hidden), + nn.LayerNorm(d_hidden), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(d_hidden, d_hidden), + nn.LayerNorm(d_hidden), + ) + self.act = nn.GELU() + + def forward(self, x): + return self.act(x + self.block(x)) + + +class SirenLayer(nn.Module): + def __init__(self, d_in, d_out, omega=30.0, is_first=False): + """ + Parameters + ---------- + omega : float + Frequency scaling factor. Default 30 following the original paper. + is_first : bool + First layer uses a different initialization scheme. + """ + super().__init__() + self.omega = omega + self.linear = nn.Linear(d_in, d_out) + self._init_weights(is_first, d_in) + + def _init_weights(self, is_first, d_in): + with torch.no_grad(): + if is_first: + bound = 1 / d_in + else: + bound = np.sqrt(6 / d_in) / self.omega + self.linear.weight.uniform_(-bound, bound) + + def forward(self, x): + return torch.sin(self.omega * self.linear(x)) + + def sinusoidal_encoding(n_tokens, d_token, device): """ Sinusoidal positional encoding over normalised arc position [0, 1]. diff --git a/src/neuron_proofreader/machine_learning/image_dataloader.py b/src/neuron_proofreader/machine_learning/image_dataloader.py index b1fd6a5c..1a965c09 100644 --- a/src/neuron_proofreader/machine_learning/image_dataloader.py +++ b/src/neuron_proofreader/machine_learning/image_dataloader.py @@ -19,11 +19,11 @@ ) from neuron_proofreader.utils import geometry_util, img_util, util - # ---------------------------------------------------------------------------- # Image Class # ---------------------------------------------------------------------------- + class TensorStoreImage: """ Class that reads images with the TensorStore library. @@ -100,6 +100,7 @@ def shape(self): # PatchLoader Class # ---------------------------------------------------------------------------- + class PatchLoader(ABC): """ A class for reading image patches and generating segment masks. @@ -176,7 +177,7 @@ def create_mask(self, center, shape, node): # Annotate mask mask = np.zeros(shape, dtype=np.float32) - #self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP + # self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP self.annotate_fragment(mask, subgraph, offset, fill_val=1) return mask @@ -231,6 +232,7 @@ def stack(img, mask): # PatchLoader Subclasses # ---------------------------------------------------------------------------- + class DetectionPatchLoader(PatchLoader): # --- Implementation of Abstract Inferface --- diff --git a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py index c85ae38c..d7c7bb8f 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py @@ -442,7 +442,7 @@ def __iter__(self): # Split into batches upfront batch_idx_groups = [ - idxs[start : min(start + self.batch_size, len(idxs))] + idxs[start: min(start + self.batch_size, len(idxs))] for start in range(0, len(idxs), self.batch_size) ] diff --git a/src/neuron_proofreader/merge_proofreading/merge_detection.py b/src/neuron_proofreader/merge_proofreading/merge_detection.py index c0736a06..81e4fb0a 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_detection.py +++ b/src/neuron_proofreader/merge_proofreading/merge_detection.py @@ -176,7 +176,9 @@ def save_results( ): self.save_sites(output_dir) if save_fragments: - self.dataset.graph.node_radius = 10 * np.maximum(self.node_preds, 0.1) + self.dataset.graph.node_radius = 10 * np.maximum( + self.node_preds, 0.1 + ) fragments_path = os.path.join(output_dir, "fragments.zip") self.dataset.to_zipped_swcs(fragments_path, use_radius=True) diff --git a/src/neuron_proofreader/merge_proofreading/search_datasets.py b/src/neuron_proofreader/merge_proofreading/search_datasets.py index 52833cbf..ad761584 100644 --- a/src/neuron_proofreader/merge_proofreading/search_datasets.py +++ b/src/neuron_proofreader/merge_proofreading/search_datasets.py @@ -15,14 +15,10 @@ from threading import Thread from torch.utils.data import IterableDataset -import itertools import networkx as nx import numpy as np import torch -from neuron_proofreader.machine_learning.point_cloud_models import ( - subgraph_to_point_cloud, -) from neuron_proofreader.machine_learning.image_dataloader import ( DetectionBatchLoader, DetectionPatchLoader, @@ -64,31 +60,33 @@ def __init__( def __iter__(self): patch_queue = Queue(maxsize=self.prefetch) sentinel = object() - + def producer(): sites = self._all_sites() with ThreadPoolExecutor(max_workers=self.prefetch) as executor: futures = {} - + def fill(): while len(futures) < self.prefetch: try: site = next(sites) - futures[executor.submit(self.patch_loader, site)] = site + futures[ + executor.submit(self.patch_loader, site) + ] = site except StopIteration: break - + fill() while futures: done, _ = wait(futures, return_when=FIRST_COMPLETED) for f in done: patch_queue.put((futures.pop(f), f.result())) fill() - + patch_queue.put(sentinel) - + Thread(target=producer, daemon=True).start() - + while True: item = patch_queue.get() if item is sentinel: diff --git a/src/neuron_proofreader/proofreading_pipeline.py b/src/neuron_proofreader/proofreading_pipeline.py index 0cb8f8ab..49b1bec3 100644 --- a/src/neuron_proofreader/proofreading_pipeline.py +++ b/src/neuron_proofreader/proofreading_pipeline.py @@ -146,7 +146,9 @@ def split_proofreading( def connect_soma_fragments(self, max_dist=25): self.step_cnt += 1 - self.log(f"\nStep {self.step_cnt}: Connect Soma Fragments with dist={max_dist}") + self.log( + f"\nStep {self.step_cnt}: Connect Soma Fragments with dist={max_dist}" + ) summary = self.graph.connect_soma_fragments(max_dist=max_dist) self.log(summary) @@ -173,7 +175,7 @@ def merge_proofreading(self, mode): swc_util.write_points( zip_path, merge_sites, color=color, prefix="merge_site", radius=10 ) - + # --- Helpers --- def log(self, txt): """ diff --git a/src/neuron_proofreader/skeleton_graph.py b/src/neuron_proofreader/skeleton_graph.py index e26a3f7c..ed649949 100644 --- a/src/neuron_proofreader/skeleton_graph.py +++ b/src/neuron_proofreader/skeleton_graph.py @@ -535,6 +535,7 @@ def component_to_zipped_swc(self, zipfile, root, use_radius=False): Indication of whether to preserve radii of nodes or use default radius of 2μm. Default is False. """ + # Subroutines def write_entry(node, parent): """ diff --git a/src/neuron_proofreader/utils/swc_util.py b/src/neuron_proofreader/utils/swc_util.py index 920af307..fc87b644 100644 --- a/src/neuron_proofreader/utils/swc_util.py +++ b/src/neuron_proofreader/utils/swc_util.py @@ -202,7 +202,9 @@ def read_zips(self, zip_paths, read_fn): pbar = self.manual_progress_bar(len(zip_paths)) with ProcessPoolExecutor() as executor: # Assign processes - futures = {executor.submit(read_fn, path) for path in zip_paths[0:500]} # TEMP + futures = { + executor.submit(read_fn, path) for path in zip_paths[0:500] + } # TEMP # Store results swc_dicts = deque() diff --git a/src/neuron_proofreader/utils/util.py b/src/neuron_proofreader/utils/util.py index df623297..4056cbfb 100644 --- a/src/neuron_proofreader/utils/util.py +++ b/src/neuron_proofreader/utils/util.py @@ -401,7 +401,7 @@ def parse_cloud_path(path): """ # Remove s3:// or gs:// if present if path.startswith("s3://") or path.startswith("gs://"): - path = path[len("s3://") :] + path = path[len("s3://"):] # Split path parts = path.split("/", 1)