Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@ def __init__(
swcs_path,
graph_config=None,
max_length=np.inf,
segment_len=10,
transform=None,
):
# Instance attributes
self.brain_id = brain_id
self.max_length = max_length
self.segment_len = segment_len
self.transform = transform

# Core data structures
Expand All @@ -53,7 +55,8 @@ def load_skeletons(self, config, swcs_path):
def get_valid_paths(self):
paths = list()
for p in self.irreducible_paths():
if self.path_length(p) < self.max_length:
length = self.path_length(p)
if length < self.max_length and len(p) > self.segment_len:
paths.append(p)
return paths

Expand Down
46 changes: 24 additions & 22 deletions src/neuron_proofreader/geometric_learning/curve_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,9 @@ def forward(self, z, n_segments=None):

if self.mlp_type == "siren":
# Raw scalar t concatenated with z — SIREN encodes frequency internally
t_exp = (
t.unsqueeze(-1).unsqueeze(0).expand(B, -1, -1)
) # (B, n_seg, 1)
z_exp = z.unsqueeze(1).expand(
-1, n_segments, -1
) # (B, n_seg, latent_dim)
x = torch.cat([z_exp, t_exp], dim=-1) # (B, n_seg, latent_dim+1)
t_exp = t.unsqueeze(-1).unsqueeze(0).expand(B, -1, -1)
z_exp = z.unsqueeze(1).expand(-1, n_segments, -1)
x = torch.cat([z_exp, t_exp], dim=-1)
out = self.mlp(x)

elif self.mlp_type == "residual":
Expand Down Expand Up @@ -414,19 +410,19 @@ def forward(self, z, n_segments=None):
outputs = []
T_prev = torch.zeros(B, d_seg, device=z.device) # T_{-1}: start token
for s in range(n_segments):
pe_s = pe[s].unsqueeze(0).expand(B, -1) # (B, 2*n_freq)
# Create GRU input
pe_s = pe[s].unsqueeze(0).expand(B, -1)
x_s = torch.cat([T_prev, pe_s, z], dim=-1).unsqueeze(1)

# Concatenate T_{i-1}, pe(t_i), and z at every step
x_s = torch.cat([T_prev, pe_s, z], dim=-1).unsqueeze(
1
) # (B, 1, d_input)
out, h = self.gru(x_s, h) # (B, 1, d_hidden)
T_i = self.output_proj(out.squeeze(1)) # (B, d_seg)
# Predict next token and update hidden state
out, h = self.gru(x_s, h)
T_i = self.output_proj(out.squeeze(1))

# Update previous token
outputs.append(T_i)
T_prev = T_i

out = torch.stack(outputs, dim=1) # (B, n_segments, d_seg)
out = torch.stack(outputs, dim=1)
return out.reshape(B, n_segments * self.segment_len, 3)


Expand Down Expand Up @@ -473,12 +469,12 @@ def __init__(
)
self.decoder = decoder

def forward(self, offsets, token_mask):
def forward(self, diffs, mask=None):
"""
Parameters
----------
offsets : torch.Tensor
Shape (B, N, 3), normalized to the unit sphere, offsets[:, 0] == 0.
diffs : torch.Tensor
Shape (B, N, 3), normalized to the unit sphere, diffs[:, 0] == 0.

Returns
-------
Expand All @@ -487,13 +483,19 @@ def forward(self, offsets, token_mask):
z : torch.Tensor
Latent vector of shape (B, latent_dim).
"""
z, encoder_tokens = self.encoder(offsets, token_mask)
n_segments = offsets.shape[1] // self.decoder.segment_len
z, encoder_tokens = self.encoder(diffs, mask)
if mask is not None:
valid_lengths = (~mask).sum(dim=1) # (B,)
n_segments = (
valid_lengths.min() // self.decoder.segment_len
).item()
else:
n_segments = diffs.shape[1] // self.decoder.segment_len
reconstruction = self.decoder(z, n_segments=n_segments)
return reconstruction, z

def encode(self, offsets):
z, _ = self.encoder(offsets)
def encode(self, diffs):
z, _ = self.encoder(diffs)
return z

# --- Helpers ---
Expand Down
179 changes: 179 additions & 0 deletions src/neuron_proofreader/geometric_learning/curve_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""
Created on Mon June 8 17:00:00 2026

@author: Anna Grim
@email: anna.grim@alleninstitute.org

Code for visualizing 3D space curves and their embeddings.

"""

from colorsys import hsv_to_rgb
from matplotlib.colors import LogNorm
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go

from neuron_proofreader.utils import geometry_util


# --- Plot Curves ---
def plot_curves(curve1, curve2, name1=None, name2=None):
    """
    Shows two 3D curves and a marker at the origin in a Plotly figure.

    Parameters
    ----------
    curve1 : array-like
        Curve drawn in blue. Indexed as curve1[:, 0], curve1[:, 1], and
        curve1[:, 2] to obtain the x, y, and z coordinates.
    curve2 : array-like
        Curve drawn in green. Indexed in the same way as "curve1".
    name1 : str, optional
        Legend name of the first curve. Default is None.
    name2 : str, optional
        Legend name of the second curve. Default is None.
    """
    # One line trace per curve
    styled_curves = ((curve1, name1, "blue"), (curve2, name2, "green"))
    traces = [
        go.Scatter3d(
            x=xyz[:, 0],
            y=xyz[:, 1],
            z=xyz[:, 2],
            mode="lines",
            name=label,
            line={"width": 5, "color": color},
        )
        for xyz, label, color in styled_curves
    ]

    # Reference marker at the origin
    traces.append(
        go.Scatter3d(
            x=[0],
            y=[0],
            z=[0],
            mode="markers",
            name="Origin",
            marker={"size": 3, "color": "red"},
        )
    )

    # Assemble and display the figure
    figure = go.Figure(data=traces)
    figure.update_layout(
        scene={"xaxis_title": "x", "yaxis_title": "y", "zaxis_title": "z"}
    )
    figure.show()


def plot_length_distribution(dataset_collection, title=None, output_path=None):
    """
    Plots a histogram (log-scaled counts) of the path lengths stored in
    "dataset_collection.examples_df", restricted to lengths below the 99.9th
    percentile, with the 50th and 99.9th percentiles marked by dashed lines.

    Parameters
    ----------
    dataset_collection : object
        Object whose "examples_df" attribute has a "length" column.
    title : str, optional
        Title of the plot. Default is None.
    output_path : str, optional
        Forwarded to "visualize_result". Default is None.
    """
    # Percentiles used both as markers and as the display cutoff
    all_lengths = dataset_collection.examples_df["length"]
    median = round(np.percentile(all_lengths, 50), 2)
    cutoff = round(np.percentile(all_lengths, 99.9), 2)
    shown_lengths = [length for length in all_lengths if length < cutoff]

    # Histogram with percentile markers
    plt.figure(figsize=(8, 5))
    plt.hist(
        shown_lengths,
        bins=50,
        edgecolor="white",
        linewidth=0.5,
        zorder=2,
    )
    add_line(median, color="r", label=f"50th perc = {median}")
    add_line(cutoff, color="g", label=f"99.9th perc = {cutoff}")

    # Axes, labels, and legend
    plt.title(title, fontsize=13)
    plt.xlabel("Path Length (μm)", fontsize=12)
    plt.ylabel("Count", fontsize=12)
    plt.yscale("log")
    plt.grid(axis="y", color="lightgrey", linewidth=0.5)
    plt.legend(loc="upper right")

    # Footnote explaining the truncated x-range
    footnote = (
        "Note: Path lengths thresholded at 99.9th percentile in this plot"
    )
    plt.figtext(
        0.01,
        -0.02,
        footnote,
        fontsize=9,
        color="tab:grey",
        va="bottom",
    )

    # NOTE(review): "visualize_result" calls plt.tight_layout() afterwards,
    # which presumably overrides this margin -- confirm it is still needed.
    plt.subplots_adjust(bottom=0.8)
    visualize_result(output_path=output_path)


# --- Plot Curve Embeddings ---
def plot_error_vs_length(lengths, rmse_results, output_path=None):
    """
    Draws a log-log scatter plot of RMSE against path length, with each
    point colored by its path length.

    Parameters
    ----------
    lengths : numpy.ndarray
        Path lengths; must provide "min" and "max" methods.
    rmse_results : array-like
        RMSE values plotted on the y-axis, one per entry of "lengths".
    output_path : str, optional
        Forwarded to "visualize_result". Default is None.
    """
    # Map each length to a viridis color on a logarithmic scale
    log_norm = LogNorm(vmin=lengths.min(), vmax=lengths.max())
    point_colors = plt.cm.viridis(log_norm(lengths))

    # Scatter plot on log-log axes
    plt.figure(figsize=(8, 6))
    plt.scatter(lengths, rmse_results, c=point_colors, s=10, alpha=0.8)
    for set_scale in (plt.xscale, plt.yscale):
        set_scale("log")

    # Labels
    plt.xlabel("Path Length (μm)", fontsize=12)
    plt.ylabel("RMSE", fontsize=12)
    visualize_result(output_path=output_path)


def plot_latents_by_pca(curves, latents, output_path=None):
    """
    Projects latent vectors onto their first two principal components and
    plots the projection twice: once colored by the principal direction of
    each curve and once colored by the length of each curve.

    Parameters
    ----------
    curves : iterable
        Curves corresponding to the rows of "latents". Each curve is passed
        to "geometry_util.compute_length" and "curve_principal_direction".
    latents : array-like
        Latent vectors given to "PCA.fit_transform".
    output_path : str or os.PathLike, optional
        Path that the length-colored plot is saved to. The direction-colored
        plot is saved next to it with "_direction" inserted before the file
        extension. If None, both plots are shown instead. Default is None.
    """
    import os

    # PCA of latents
    pca = PCA(n_components=2)
    latents_2d = pca.fit_transform(latents)
    lengths = np.array([geometry_util.compute_length(c) for c in curves])

    # Two figures are produced, so the direction plot needs its own file;
    # otherwise the length plot saved afterwards overwrites it. The length
    # plot keeps "output_path" so the file callers expect is unchanged.
    direction_path = output_path
    if output_path and isinstance(output_path, (str, os.PathLike)):
        root, ext = os.path.splitext(os.fspath(output_path))
        direction_path = f"{root}_direction{ext}"

    # Visualize results
    _plot_latents_by_direction(curves, latents_2d, pca, direction_path)
    _plot_latents_by_length(lengths, latents_2d, pca, output_path)


def _plot_latents_by_direction(curves, latents_2d, pca, output_path=None):
    """
    Scatter plot of 2D latent projections in which the color of each point
    is derived from the principal direction of the corresponding curve.

    Parameters
    ----------
    curves : iterable
        Curves whose principal directions determine the point colors.
    latents_2d : numpy.ndarray
        Projected latents; columns 0 and 1 are the x and y coordinates.
    pca : sklearn.decomposition.PCA
        Fitted PCA whose "explained_variance_ratio_" labels the axes.
    output_path : str, optional
        Forwarded to "visualize_result". Default is None.
    """
    # Principal direction of every curve, then one RGB color per direction
    principal_dirs = np.array([curve_principal_direction(c) for c in curves])
    point_colors = np.array([direction_to_color(d) for d in principal_dirs])

    # Axis labels report the variance explained by each component
    ratio = pca.explained_variance_ratio_
    pc1_label = f"PC1 ({ratio[0]*100:.1f}%)"
    pc2_label = f"PC2 ({ratio[1]*100:.1f}%)"

    # Plot
    plt.figure(figsize=(8, 6))
    xs, ys = latents_2d[:, 0], latents_2d[:, 1]
    plt.scatter(xs, ys, c=point_colors, s=10, alpha=0.8)
    plt.xlabel(pc1_label)
    plt.ylabel(pc2_label)
    plt.title("Curve Embeddings Colored by Principal Direction")
    visualize_result(output_path=output_path)


def _plot_latents_by_length(lengths, latents_2d, pca, output_path=None):
    """
    Scatter plot of 2D latent projections colored by path length on a
    logarithmic color scale, with a colorbar.

    Parameters
    ----------
    lengths : numpy.ndarray
        Path length of each point; must provide "min" and "max" methods.
    latents_2d : numpy.ndarray
        Projected latents; columns 0 and 1 are the x and y coordinates.
    pca : sklearn.decomposition.PCA
        Fitted PCA whose "explained_variance_ratio_" labels the axes.
    output_path : str, optional
        Forwarded to "visualize_result". Default is None.
    """
    # Logarithmic color scale spanning the observed lengths
    color_scale = LogNorm(vmin=lengths.min(), vmax=lengths.max())
    ratio = pca.explained_variance_ratio_

    # Plot
    plt.figure(figsize=(8, 6))
    xs, ys = latents_2d[:, 0], latents_2d[:, 1]
    points = plt.scatter(
        xs,
        ys,
        c=lengths,
        cmap="viridis",
        s=10,
        alpha=0.7,
        norm=color_scale,
    )
    plt.colorbar(points, label="Path length (μm)")

    # Labels
    plt.xlabel(f"PC1 ({ratio[0]*100:.1f}%)")
    plt.ylabel(f"PC2 ({ratio[1]*100:.1f}%)")
    plt.title("PCA of curve embeddings")
    visualize_result(output_path=output_path)


# --- Helpers ---
def add_line(p, color=None, label=None):
    """
    Adds a dashed vertical line at x = "p" to the current axes.

    Parameters
    ----------
    p : float
        Position of the line on the x-axis.
    color : str, optional
        Color of the line. Default is None.
    label : str, optional
        Legend label of the line. Default is None.
    """
    line_style = {"linestyle": "--", "zorder": 2}
    plt.axvline(p, color=color, label=label, **line_style)


def curve_principal_direction(curve):
    """
    Computes the first principal axis of a curve, oriented so that its third
    component is non-negative, and scaled to unit length.

    Parameters
    ----------
    curve : array-like
        Points of the curve given to "PCA.fit".

    Returns
    -------
    numpy.ndarray
        Unit vector along the first principal axis of the curve.
    """
    axis = PCA(n_components=1).fit(curve).components_[0]

    # Flip the axis when it points towards negative z so that a direction
    # and its negation are reported as the same vector
    oriented = -axis if axis[2] < 0 else axis
    return oriented / np.linalg.norm(oriented)


def direction_to_color(direction):
    """
    Converts a 3D direction into an RGB color via the HSV color space.

    The azimuth in the xy-plane selects the hue, the polar angle measured
    from the +z axis selects the saturation, and the value is fixed at 1.

    Parameters
    ----------
    direction : sequence of float
        Three components (x, y, z) of the direction.

    Returns
    -------
    Tuple[float]
        RGB color returned by "colorsys.hsv_to_rgb".
    """
    dx, dy, dz = direction

    # Shift the azimuth from [-pi, pi] onto [0, 1] to use it as a hue
    hue = (np.arctan2(dy, dx) + np.pi) / (2 * np.pi)

    # z is clipped to [0, 1], so the polar angle lies in [0, pi / 2] and the
    # saturation in [0, 1]
    saturation = np.arccos(np.clip(dz, 0, 1)) / (np.pi / 2)
    return hsv_to_rgb(hue, saturation, 1.0)


def visualize_result(output_path=None):
    """
    Finalizes the current Matplotlib figure by either saving it to disk or
    showing it.

    Parameters
    ----------
    output_path : str, optional
        Path to save the figure to. If None or empty, the figure is shown
        instead. Default is None.
    """
    plt.tight_layout()
    if output_path:
        fig = plt.gcf()
        fig.savefig(output_path, dpi=300, bbox_inches="tight")

        # pyplot keeps every figure it creates alive until it is closed, so
        # release the saved figure to avoid accumulating open figures when
        # many plots are written in one session
        plt.close(fig)
    else:
        plt.show()
31 changes: 0 additions & 31 deletions src/neuron_proofreader/machine_learning/vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

"""

# from neurobase.finetune import finetune_model
from einops import rearrange

import torch
Expand Down Expand Up @@ -131,36 +130,6 @@ def forward(self, x):


# --- Transformers ---
class MAE3D(nn.Module):

def __init__(self, checkpoint_path, model_config):
# Call parent class
super().__init__()

# Load model
full_model = finetune_model(
checkpoint_path=checkpoint_path,
model_config=model_config,
task_head_config="binary_classifier",
freeze_encoder=True,
)

# Instance attributes
self.encoder = full_model.encoder
self.output = FeedForwardNet(384, 1, 3)

def forward(self, x):
latent0 = self.encoder(x[:, 0:1, ...])
latent1 = self.encoder(x[:, 1:2, ...])

x0 = latent0["latents"][:, 0, :]
x1 = latent1["latents"][:, 0, :]

x = torch.cat((x0, x1), dim=1)
x = self.output(x)
return x


class ViT3D(nn.Module):
"""
A class that implements a 3D Vision transformer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def __iter__(self):

# Split into batches upfront
batch_idx_groups = [
idxs[start: min(start + self.batch_size, len(idxs))]
idxs[start : min(start + self.batch_size, len(idxs))]
for start in range(0, len(idxs), self.batch_size)
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
DetectionBatchLoader,
DetectionPatchLoader,
)
from neuron_proofreader.utils import img_util, ml_util, util
from neuron_proofreader.utils import img_util, util


# --- Datasets ---
Expand Down
Loading
Loading