def get_valid_paths(self):
    """
    Gets the irreducible paths that satisfy the length criteria.

    A path is kept when it contains more than "self.segment_len" points and
    its path length is strictly less than "self.max_length".

    Returns
    -------
    list
        Paths generated by "self.irreducible_paths()" that satisfy both
        criteria, in the order they were generated.
    """
    # Check the cheap point count first so that "path_length" is only
    # computed for paths that pass it
    return [
        p
        for p in self.irreducible_paths()
        if len(p) > self.segment_len
        and self.path_length(p) < self.max_length
    ]
def forward(self, diffs, mask=None):
    """
    Encodes a batch of curves and decodes a reconstruction from the latents.

    Parameters
    ----------
    diffs : torch.Tensor
        Shape (B, N, 3), normalized to the unit sphere, diffs[:, 0] == 0.
    mask : torch.Tensor, optional
        Boolean tensor, presumably with shape (B, N), that is passed to the
        encoder. Entries that are False are counted as valid points of a
        curve. Default is None.

    Returns
    -------
    reconstruction : torch.Tensor
        Output of the decoder for "n_segments" segments. When a mask is
        given, the number of segments is determined by the curve with the
        fewest valid points in the batch.
    z : torch.Tensor
        Latent vector of shape (B, latent_dim).
    """
    z, _ = self.encoder(diffs, mask)

    # Number of points that every curve in the batch can account for
    if mask is None:
        n_valid = diffs.shape[1]
    else:
        # Limited by the curve with the fewest unmasked points
        n_valid = int((~mask).sum(dim=1).min())

    n_segments = n_valid // self.decoder.segment_len
    reconstruction = self.decoder(z, n_segments=n_segments)
    return reconstruction, z


def encode(self, diffs, mask=None):
    """
    Computes the latent vectors of a batch of curves.

    Parameters
    ----------
    diffs : torch.Tensor
        Shape (B, N, 3), same input that is given to "forward".
    mask : torch.Tensor, optional
        Mask that is forwarded to the encoder when provided, so that the
        latents agree with the ones computed by "forward". Default is None.

    Returns
    -------
    z : torch.Tensor
        Latent vectors produced by the encoder.
    """
    # Keep the original single-argument call when no mask is given
    if mask is None:
        z, _ = self.encoder(diffs)
    else:
        z, _ = self.encoder(diffs, mask)
    return z
def plot_length_distribution(dataset_collection, title=None, output_path=None):
    """
    Plots a histogram of path lengths with a log-scaled count axis.

    Lengths at or above the 99.9th percentile are left out of the histogram.
    The 50th and 99.9th percentiles are marked with dashed vertical lines.

    Parameters
    ----------
    dataset_collection : object
        Object with an "examples_df" attribute that has a "length" column.
    title : str, optional
        Title of the plot. Default is None.
    output_path : str, optional
        Passed to "visualize_result". Default is None.
    """
    # Compute path length stats
    lengths = np.asarray(dataset_collection.examples_df["length"])
    p50 = round(np.percentile(lengths, 50), 2)
    p999 = round(np.percentile(lengths, 99.9), 2)
    thr_lengths = lengths[lengths < p999]

    # Plot path lengths
    plt.figure(figsize=(8, 5))
    plt.hist(thr_lengths, bins=50, edgecolor="white", linewidth=0.5, zorder=2)
    add_line(p50, color="r", label=f"50th perc = {p50}")
    add_line(p999, color="g", label=f"99.9th perc = {p999}")

    # Plot labels
    plt.grid(axis="y", color="lightgrey", linewidth=0.5)
    plt.legend(loc="upper right")
    plt.title(title, fontsize=13)
    plt.xlabel("Path Length (μm)", fontsize=12)
    plt.ylabel("Count", fontsize=12)
    plt.yscale("log")

    plt.figtext(
        0.01,
        -0.02,
        "Note: Path lengths thresholded at 99.9th percentile in this plot",
        fontsize=9,
        color="tab:grey",
        va="bottom",
    )

    # Margins are not set manually: "visualize_result" calls tight_layout,
    # which recomputes the subplot parameters
    visualize_result(output_path=output_path)
def plot_latents_by_pca(curves, latents, output_path=None):
    """
    Plots a 2D PCA projection of curve embeddings, colored in two ways.

    Two figures are generated: one colored by the principal direction of
    each curve and one colored by the length of each curve.

    Parameters
    ----------
    curves : List[numpy.ndarray]
        Curves that the latents were computed from.
    latents : numpy.ndarray
        Latent vectors, one row per curve.
    output_path : str, optional
        Path that the length-colored figure is saved to. The
        direction-colored figure is saved next to it with "_direction"
        appended to the file name stem. If None, the figures are shown
        instead. Default is None.
    """
    import os

    # PCA of latents
    pca = PCA(n_components=2)
    latents_2d = pca.fit_transform(latents)
    lengths = np.array([geometry_util.compute_length(c) for c in curves])

    # Give the direction figure its own file; otherwise it is overwritten by
    # the length figure, which is saved to "output_path"
    direction_path = None
    if output_path:
        root, ext = os.path.splitext(os.fspath(output_path))
        direction_path = f"{root}_direction{ext}"

    # Visualize results
    _plot_latents_by_direction(curves, latents_2d, pca, direction_path)
    _plot_latents_by_length(lengths, latents_2d, pca, output_path)
def visualize_result(output_path=None):
    """
    Saves the current figure to disk or shows it.

    Parameters
    ----------
    output_path : str, optional
        Path that the current figure is saved to. If None, the figure is
        shown instead. Default is None.
    """
    plt.tight_layout()
    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches="tight")

        # Saved figures stay registered with pyplot until they are closed,
        # so release the figure to avoid accumulating them across calls
        plt.close()
    else:
        plt.show()
def compute_length(curve):
    """
    Computes the Euclidean length of the given curve.

    Parameters
    ----------
    curve : numpy.ndarray
        Array of points that form an n-d curve, one point per row.

    Returns
    -------
    float
        Euclidean length of the given curve, namely the sum of the distances
        between consecutive points. A curve with fewer than two points has
        length 0.
    """
    curve = np.asarray(curve)
    diffs = curve[1:] - curve[:-1]

    # The norm already squares the components, so the differences must not
    # be squared beforehand
    return np.linalg.norm(diffs, axis=1).sum()
def rmse(curve1, curve2):
    """
    Computes Root Mean Squared Error (RMSE) between two curves.

    The error of a pair of points is the Euclidean distance between them,
    where the i-th point of one curve is paired with the i-th point of the
    other.

    Parameters
    ----------
    curve1 : numpy.ndarray
        Array of points that form a curve, one point per row.
    curve2 : numpy.ndarray
        Array of points that form a curve, one point per row. Must be
        broadcastable against "curve1".

    Returns
    -------
    float
        Square root of the mean squared distance between paired points.
    """
    curve1 = np.asarray(curve1)
    curve2 = np.asarray(curve2)
    sq_dists = np.sum((curve1 - curve2) ** 2, axis=1)
    return np.sqrt(np.mean(sq_dists))