Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
dist/
.DS_Store
src/.DS_Store
src/.DS_Store
src/compass/__pycache__
3 changes: 2 additions & 1 deletion src/compass/ModelTransfuser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import sys
import pickle
import tqdm

import torch
import torch.nn as nn
Expand Down Expand Up @@ -245,7 +246,7 @@ def compare(self, x, err=None, condition_mask=None,
self.softmax = nn.Softmax(dim=0)

# Loop over all models
for model_name, model in self.models_dict.items():
for model_name, model in tqdm.tqdm(self.models_dict.items(), desc="Comparing models", unit="model"):
self.stats[model_name] = {}
if condition_mask is None:
condition_mask = torch.cat([torch.zeros(model.nodes_size-x.shape[-1]),torch.ones(x.shape[-1])])
Expand Down
141 changes: 82 additions & 59 deletions src/compass/Sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,27 +223,50 @@ def _gather_samples(self, all_samples, indices, result_dict):
#############################################

def _get_score(self, x, t, condition_mask, cfg_alpha=None):
"""Get score estimate with optional classifier-free guidance"""
# Get conditional score
"""Get score estimate with optional classifier-free guidance.

Handles both 2D input (num_samples, nodes_size) and
3D input (num_obs, num_samples, nodes_size) by flattening.
"""
input_shape = x.shape

# Flatten if 3D: (num_obs, num_samples, N) → (num_obs * num_samples, N)
if x.dim() == 3:
num_obs, num_samples, nodes_size = x.shape
x_flat = x.reshape(-1, nodes_size)
c_flat = condition_mask.reshape(-1, nodes_size)
else:
x_flat = x
c_flat = condition_mask

with torch.no_grad():
# Check if Attention weights should be returned
if t.item() == self.attn_weights_time and self.return_attn_weights:
score_cond, attn_weights = self.model(x=x, t=t, c=condition_mask, return_attn_weights=True)
# Attention weight extraction (only once per sampling run)
if (abs(t.flatten()[0].item() - self.attn_weights_time.item()) < 1e-6
and self.return_attn_weights):
score_cond, attn_weights = self.model(
x=x_flat, t=t, c=c_flat, return_attn_weights=True
)
self.all_attn_weights.append(attn_weights)
self.return_attn_weights = False # Only return once per sample
self.return_attn_weights = False
else:
score_cond = self.model(x=x, t=t, c=condition_mask)

score_cond = self.model(x=x_flat, t=t, c=c_flat)
score_cond = self.SBIm.output_scale_function(t, score_cond)

# Apply classifier-free guidance if requested
# Classifier-free guidance
if cfg_alpha is not None:
score_uncond = self.model(x=x, t=t, c=torch.zeros_like(condition_mask))
score_uncond = self.model(
x=x_flat, t=t, c=torch.zeros_like(c_flat)
)
score_uncond = self.SBIm.output_scale_function(t, score_uncond)
score = score_uncond + cfg_alpha * (score_cond - score_uncond)
else:
score = score_cond


# Unflatten back to original shape
if len(input_shape) == 3:
score = score.reshape(input_shape)

return score

def _check_data_shape(self, data, condition_mask, err):
Expand Down Expand Up @@ -349,25 +372,25 @@ def _basic_sampler(self, data, condition_mask):
self.data_t[:,0,:,:] = data

# Main sampling loop
for n in tqdm.tqdm(range(len(data)), disable=not self.verbose):
for i, t in enumerate(self.timesteps_list):

t = t.reshape(-1, 1)
# Get score estimate
score = self._get_score(data[n,:], t, condition_mask[n,:], self.cfg_alpha)
# Update step
dx = self.sde.sigma**(2*t) * score * self.dt
# Apply update respecting condition mask
data[n,:] = data[n,:] + dx * (1-condition_mask[n,:])
if self.save_trajectory:
# Store trajectory data
self.data_t[n,i+1] = data[n,:]
self.dx_t[n,i] = dx
self.score_t[n,i] = score

for i, t in tqdm.tqdm(enumerate(self.timesteps_list), disable=not self.verbose):

t = t.reshape(-1, 1)

# Get score estimate
score = self._get_score(data, t, condition_mask, self.cfg_alpha)

# Update step
dx = self.sde.sigma**(2*t) * score * self.dt

# Apply update respecting condition mask
data = data + dx * (1-condition_mask)

if self.save_trajectory:
# Store trajectory data
self.data_t[:,i+1,:,:] = data
self.dx_t[:,i,:,:] = dx
self.score_t[:,i,:,:] = score

return data.detach()

Expand Down Expand Up @@ -475,35 +498,35 @@ def _dpm_sampler(self, data, condition_mask,
self.data_t[:,0,:,:] = data

# Main sampling loop
for n in tqdm.tqdm(range(len(data)), disable=not self.verbose):
for i in range(self.timesteps-1):
# ------- PREDICTOR: DPM-Solver -------
t_now = self.timesteps_list[i].reshape(-1, 1)
t_next = self.timesteps_list[i+1].reshape(-1, 1)

if order == 1:
data[n,:] = self._dpm_solver_1_step(data[n,:], t_now, t_next, condition_mask[n,:])
elif order == 2:
data[n,:] = self._dpm_solver_2_step(data[n,:], t_now, t_next, condition_mask[n,:])
elif order == 3:
data[n,:] = self._dpm_solver_3_step(data[n,:], t_now, t_next, condition_mask[n,:])
else:
raise ValueError(f"Only orders 1, 2 or 3 are supported in the DPM-Solver.")
# ------- CORRECTOR: Langevin MCMC steps -------
# Only apply corrector steps occasionally to save computation
if corrector_steps > 0 and (i % corrector_steps_interval == 0 or i >= self.timesteps - final_corrector_steps):
steps = corrector_steps
if i >= self.timesteps - final_corrector_steps:
steps = corrector_steps * 2 # More steps at the end
data[n,:] = self._corrector_step(data[n,:], t_next, condition_mask[n,:],
steps, snr, self.cfg_alpha)

if self.save_trajectory:
# Store trajectory data
self.data_t[n,i+1] = data[n,:]

for i in tqdm.tqdm(range(self.timesteps-1), disable=not self.verbose):

# ------- PREDICTOR: DPM-Solver -------
t_now = self.timesteps_list[i].reshape(-1, 1)
t_next = self.timesteps_list[i+1].reshape(-1, 1)

if order == 1:
data = self._dpm_solver_1_step(data, t_now, t_next, condition_mask)
elif order == 2:
data = self._dpm_solver_2_step(data, t_now, t_next, condition_mask)
elif order == 3:
data = self._dpm_solver_3_step(data, t_now, t_next, condition_mask)
else:
raise ValueError(f"Only orders 1, 2 or 3 are supported in the DPM-Solver.")

# ------- CORRECTOR: Langevin MCMC steps -------
# Only apply corrector steps occasionally to save computation
if corrector_steps > 0 and (i % corrector_steps_interval == 0 or i >= self.timesteps - final_corrector_steps):
steps = corrector_steps
if i >= self.timesteps - final_corrector_steps:
steps = corrector_steps * 2 # More steps at the end

data = self._corrector_step(data, t_next, condition_mask,
steps, snr, self.cfg_alpha)

if self.save_trajectory:
# Store trajectory data
self.data_t[:,i+1] = data

return data.detach()

1,230 changes: 1,230 additions & 0 deletions tutorials/Gaussian_models_baselines.ipynb

Large diffs are not rendered by default.

Loading