17 changes: 13 additions & 4 deletions neurometry/datasets/synthetic.py
@@ -75,7 +75,9 @@ def hypersphere(intrinsic_dim, num_points, radius=1):
"""
unit_hypersphere = Hypersphere(dim=intrinsic_dim)
unit_hypersphere_points = unit_hypersphere.random_point(n_samples=num_points)
intrinsic_coords = unit_hypersphere.extrinsic_to_intrinsic_coords(unit_hypersphere_points)
intrinsic_coords = unit_hypersphere.extrinsic_to_intrinsic_coords(
unit_hypersphere_points
)
return radius * unit_hypersphere_points, intrinsic_coords


@@ -110,9 +112,14 @@ def hypertorus(intrinsic_dim, num_points, radii=None, parameterization="flat"):
hypertorus_points[:, _, :] = radii[_] * unit_hypertorus_points[:, _, :]
intrinsic_coords = torch.zeros(num_points, intrinsic_dim)
for i, factor in enumerate(unit_hypertorus.factors):
intrinsic_coords[:, i] = factor.extrinsic_to_intrinsic_coords(hypertorus_points[:, i, :]).squeeze()
intrinsic_coords[:, i] = factor.extrinsic_to_intrinsic_coords(
hypertorus_points[:, i, :]
).squeeze()

return gs.reshape(hypertorus_points, (num_points, intrinsic_dim * 2)), intrinsic_coords
return (
gs.reshape(hypertorus_points, (num_points, intrinsic_dim * 2)),
intrinsic_coords,
)


def cylinder(num_points, radius=1):
@@ -132,7 +139,9 @@ def cylinder(num_points, radius=1):
cylinder = ProductManifold(factors=factors)
cylinder_points = cylinder.random_point(n_samples=num_points, bound=1)
intrinsic_coords = torch.zeros(num_points, 2)
intrinsic_coords[:, 0] = factors[0].extrinsic_to_intrinsic_coords(cylinder_points[:, :2]).squeeze()
intrinsic_coords[:, 0] = (
factors[0].extrinsic_to_intrinsic_coords(cylinder_points[:, :2]).squeeze()
)
intrinsic_coords[:, 1] = cylinder_points[:, 2]
cylinder_points[:, :2] = radius * cylinder_points[:, :2]
return cylinder_points, intrinsic_coords
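
The reformatted helpers still return an (extrinsic points, intrinsic coordinates) pair. A minimal usage sketch, assuming the functions are importable from neurometry.datasets.synthetic and the pytorch geomstats backend is set; the sample sizes below are illustrative only:

import os

os.environ["GEOMSTATS_BACKEND"] = "pytorch"

from neurometry.datasets import synthetic

# 2-sphere in R^3, rescaled to radius 2; intrinsic coords are angular
sphere_points, sphere_angles = synthetic.hypersphere(intrinsic_dim=2, num_points=500, radius=2)

# flat 2-torus embedded in R^4
torus_points, torus_angles = synthetic.hypertorus(intrinsic_dim=2, num_points=500)

# cylinder S^1 x R in R^3 with unit radius
cyl_points, cyl_coords = synthetic.cylinder(num_points=500, radius=1)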
3 changes: 3 additions & 0 deletions neurometry/estimators/geometry/__init__.py
@@ -0,0 +1,3 @@
"""Initialize folder as submodule."""

__version__ = "0.0.1"
163 changes: 163 additions & 0 deletions neurometry/estimators/geometry/immersion_estimator.py
@@ -0,0 +1,163 @@
import os

os.environ["GEOMSTATS_BACKEND"] = "pytorch"
import geomstats.backend as gs
import torch
import torch.optim as optim
from sklearn.base import BaseEstimator
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

import neurometry.estimators.geometry.models.train_config as train_config


class ImmersionEstimator(BaseEstimator):
def __init__(self, extrinsic_dim, topology, device, verbose=False):
self.estimate_ = None
self.verbose = verbose
self.device = device
self.extrinsic_dim = extrinsic_dim
self.topology = topology
self.latent_dims = {"circle": 2, "sphere": 3, "torus": 4}
self.model = self._get_model()

def _get_model(self):
return NeuralEmbedding(
latent_dim=self.latent_dims[self.topology], extrinsic_dim=self.extrinsic_dim
).to(self.device)

def intrinsic_to_extrinsic(self, x):
if self.topology == "circle":
return gs.array([gs.cos(x), gs.sin(x)]).T
if self.topology == "sphere":
return gs.array(
[
gs.sin(x[:, 0]) * gs.cos(x[:, 1]),
gs.sin(x[:, 0]) * gs.sin(x[:, 1]),
gs.cos(x[:, 0]),
]
).T
if self.topology == "torus":
return gs.array(
[
gs.cos(x[:, 0]),
gs.sin(x[:, 0]),
gs.cos(x[:, 1]),
gs.sin(x[:, 1]),
]
).T
raise ValueError("Topology not supported")

def fit(self, X, y=None):
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
X_train = torch.tensor(X_train).to(self.device)
X_test = torch.tensor(X_test).to(self.device)
y_train = torch.tensor(y_train).to(self.device)
y_test = torch.tensor(y_test).to(self.device)
train_loader = DataLoader(
TensorDataset(X_train, y_train),
batch_size=train_config.batch_size,
shuffle=True,
)
test_loader = DataLoader(
TensorDataset(X_test, y_test),
batch_size=train_config.batch_size,
shuffle=True,
)

trainer = Trainer(
self.model,
train_loader,
test_loader,
criterion=torch.nn.MSELoss(),
learning_rate=train_config.lr,
scheduler=False,
verbose=self.verbose,
)
trainer.train(train_config.num_epochs)

self.trainer = trainer

self.model.eval()
self.estimate_ = lambda task_variable: self.model(
self.intrinsic_to_extrinsic(task_variable)
)

return self


class NeuralEmbedding(torch.nn.Module):
def __init__(
self, latent_dim, extrinsic_dim, hidden_dims=64, num_hidden=4, sft_beta=4.5
):
super().__init__()

self.fc1 = torch.nn.Linear(latent_dim, hidden_dims)
self.fc_hidden = torch.nn.ModuleList(
[torch.nn.Linear(hidden_dims, hidden_dims) for _ in range(num_hidden)]
)
self.fc_output = torch.nn.Linear(hidden_dims, extrinsic_dim)
self.softplus = torch.nn.Softplus(beta=sft_beta)

def forward(self, x):
h = self.softplus(self.fc1(x))
for fc in self.fc_hidden:
h = self.softplus(fc(h))
return self.fc_output(h)


class Trainer:
def __init__(
self,
model,
train_loader,
test_loader,
criterion,
learning_rate,
scheduler=False,
verbose=False,
):
self.model = model
self.train_loader = train_loader
self.test_loader = test_loader
self.criterion = criterion
self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
if scheduler:
self.scheduler = optim.lr_scheduler.StepLR(
self.optimizer, step_size=10, gamma=0.1
)
self.verbose = verbose

def train(self, num_epochs=10):
train_losses = []
test_losses = []
for epoch in range(num_epochs):
self.model.train()
train_loss = 0.0
for inputs, targets in self.train_loader:
self.optimizer.zero_grad()
outputs = self.model(inputs)
loss = self.criterion(outputs, targets)
loss.backward()
self.optimizer.step()
train_loss += loss.item()
avg_train_loss = train_loss / len(self.train_loader)
train_losses.append(avg_train_loss)
avg_test_loss = self.evaluate()
test_losses.append(avg_test_loss)
if self.verbose:
print(
f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss}, Test Loss: {avg_test_loss}"
)
self.train_losses = train_losses
self.test_losses = test_losses

def evaluate(self):
self.model.eval()
test_loss = 0.0
with torch.no_grad():
for inputs, targets in self.test_loader:
outputs = self.model(inputs)
loss = self.criterion(outputs, targets)
test_loss += loss.item()
return test_loss / len(self.test_loader)
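
A rough usage sketch for ImmersionEstimator. The data sizes, the random placeholder activity, and the float32 casts below are illustrative assumptions, not part of this diff: fit regresses points on the topological template (here the circle in R^2) onto neural activity, and the fitted estimate_ maps intrinsic task variables to predicted neural states.

import numpy as np
import torch

from neurometry.estimators.geometry.immersion_estimator import ImmersionEstimator

num_points, num_neurons = 1000, 50  # hypothetical sizes
angles = np.random.uniform(0, 2 * np.pi, size=num_points).astype("float32")

# placeholder neural activity; in practice this would be recorded responses
neural_activity = np.random.rand(num_points, num_neurons).astype("float32")

estimator = ImmersionEstimator(
    extrinsic_dim=num_neurons, topology="circle", device="cpu", verbose=True
)

# inputs live on the circle embedded in R^2, targets are the neural responses
X = np.stack([np.cos(angles), np.sin(angles)], axis=1)
estimator.fit(X, neural_activity)

# estimate_ composes intrinsic_to_extrinsic with the trained network
predicted_states = estimator.estimate_(torch.tensor(angles))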
18 changes: 18 additions & 0 deletions neurometry/estimators/geometry/metric_estimator.py
@@ -0,0 +1,18 @@
import os

os.environ["GEOMSTATS_BACKEND"] = "pytorch"
from geomstats.geometry.base import ImmersedSet
from geomstats.geometry.euclidean import Euclidean


class NeuralManifoldIntrinsic(ImmersedSet):
def __init__(self, dim, neural_embedding_dim, neural_immersion, equip=True):
self.neural_embedding_dim = neural_embedding_dim
super().__init__(dim=dim, equip=equip)
self.neural_immersion = neural_immersion

def immersion(self, point):
return self.neural_immersion(point)

def _define_embedding_space(self):
return Euclidean(dim=self.neural_embedding_dim)
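
NeuralManifoldIntrinsic wraps an arbitrary (e.g. learned) immersion as a geomstats ImmersedSet, so that with equip=True the Euclidean metric of the embedding space is pulled back through the immersion. Below is a toy sketch with a hand-written immersion standing in for a trained network; whether metric_matrix works out of the box depends on the installed geomstats version and its autodiff support, so treat the last line as an assumption.

import torch

from neurometry.estimators.geometry.metric_estimator import NeuralManifoldIntrinsic


def toy_immersion(point):
    """Map an intrinsic angle of shape (..., 1) to a circle sitting in R^3."""
    theta = point[..., 0]
    return torch.stack(
        [torch.cos(theta), torch.sin(theta), torch.zeros_like(theta)], dim=-1
    )


manifold = NeuralManifoldIntrinsic(
    dim=1, neural_embedding_dim=3, neural_immersion=toy_immersion
)

base_point = torch.tensor([0.3])
print(manifold.immersion(base_point))             # point on the embedded circle
print(manifold.metric.metric_matrix(base_point))  # pullback metric, here ~[[1.]]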
3 changes: 3 additions & 0 deletions neurometry/estimators/geometry/models/__init__.py
@@ -0,0 +1,3 @@
"""Initialize folder as submodule."""

__version__ = "0.0.1"
Empty file.
@@ -0,0 +1,56 @@
import math

import torch


class HypersphericalUniform(torch.distributions.Distribution):
arg_constraints = {}
support = torch.distributions.constraints.real
has_rsample = False
_mean_carrier_measure = 0

@property
def dim(self):
return self._dim

@property
def device(self):
return self._device

@device.setter
def device(self, val):
self._device = val if isinstance(val, torch.device) else torch.device(val)

def __init__(self, dim, validate_args=None, device=None):
super().__init__(torch.Size([dim]), validate_args=validate_args)
self._dim = dim
self.device = device

def sample(self, shape=None):
if shape is None:
shape = torch.Size()
output = (
torch.distributions.Normal(0, 1)
.sample(
(shape if isinstance(shape, torch.Size) else torch.Size([shape]))
+ torch.Size([self._dim + 1])
)
.to(self.device)
)

return output / output.norm(dim=-1, keepdim=True)

def entropy(self):
return self.__log_surface_area()

def log_prob(self, x):
return -torch.ones(x.shape[:-1], device=self.device) * self.__log_surface_area()

def __log_surface_area(self):
if torch.__version__ >= "1.0.0":
lgamma = torch.lgamma(torch.tensor([(self._dim + 1) / 2]).to(self.device))
else:
lgamma = torch.lgamma(
torch.Tensor([(self._dim + 1) / 2], device=self.device)
)
return math.log(2) + ((self._dim + 1) / 2) * math.log(math.pi) - lgamma
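
For reference, a quick check of the distribution above; the class is assumed to be in scope, since the file path for this hunk is not shown in this view, and a CPU device is passed explicitly.

import torch

# uniform distribution on the 2-sphere: samples live in R^3 and have unit norm
dist = HypersphericalUniform(dim=2, device=torch.device("cpu"))

samples = dist.sample(torch.Size([5]))   # shape (5, 3)
print(samples.norm(dim=-1))              # all ones, up to float error
print(dist.log_prob(samples))            # constant: -log(area of S^2) = -log(4*pi)
print(dist.entropy())                    # log(4*pi)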