diff --git a/rfdiffusion/Attention_module.py b/rfdiffusion/Attention_module.py index f8868fc2..0e345733 100644 --- a/rfdiffusion/Attention_module.py +++ b/rfdiffusion/Attention_module.py @@ -60,20 +60,14 @@ def forward(self, query, key, value): B, Q = query.shape[:2] B, K = key.shape[:2] # - query = self.to_q(query).reshape(B, Q, self.h, self.dim) - key = self.to_k(key).reshape(B, K, self.h, self.dim) - value = self.to_v(value).reshape(B, K, self.h, self.dim) - # - query = query * self.scaling - attn = einsum('bqhd,bkhd->bhqk', query, key) - attn = F.softmax(attn, dim=-1) - # - out = einsum('bhqk,bkhd->bqhd', attn, value) - out = out.reshape(B, Q, self.h*self.dim) - # - out = self.to_out(out) - - return out + # (B, seq, h, d) -> (B, h, seq, d) for scaled_dot_product_attention + query = self.to_q(query).reshape(B, Q, self.h, self.dim).transpose(1, 2) + key = self.to_k(key ).reshape(B, K, self.h, self.dim).transpose(1, 2) + value = self.to_v(value).reshape(B, K, self.h, self.dim).transpose(1, 2) + # scaling and softmax handled internally; uses Flash Attention when available + out = F.scaled_dot_product_attention(query, key, value) # (B, h, Q, d) + out = out.transpose(1, 2).reshape(B, Q, self.h * self.dim) + return self.to_out(out) class AttentionWithBias(nn.Module): def __init__(self, d_in=256, d_bias=128, n_head=8, d_hidden=32): @@ -117,22 +111,17 @@ def forward(self, x, bias): x = self.norm_in(x) bias = self.norm_bias(bias) # - query = self.to_q(x).reshape(B, L, self.h, self.dim) - key = self.to_k(x).reshape(B, L, self.h, self.dim) - value = self.to_v(x).reshape(B, L, self.h, self.dim) - bias = self.to_b(bias) # (B, L, L, h) - gate = torch.sigmoid(self.to_g(x)) - # - key = key * self.scaling - attn = einsum('bqhd,bkhd->bqkh', query, key) - attn = attn + bias - attn = F.softmax(attn, dim=-2) - # - out = einsum('bqkh,bkhd->bqhd', attn, value).reshape(B, L, -1) + # (B, L, h, d) -> (B, h, L, d); bias (B, L, L, h) -> (B, h, L, L) + query = self.to_q(x).reshape(B, L, self.h, self.dim).transpose(1, 2) + key = self.to_k(x).reshape(B, L, self.h, self.dim).transpose(1, 2) + value = self.to_v(x).reshape(B, L, self.h, self.dim).transpose(1, 2) + bias = self.to_b(bias).permute(0, 3, 1, 2) # (B, h, L, L) + gate = torch.sigmoid(self.to_g(x)) + # bias added to logits before softmax; Flash Attention used when available + out = F.scaled_dot_product_attention(query, key, value, attn_mask=bias) + out = out.transpose(1, 2).reshape(B, L, -1) # (B, L, h*d) out = gate * out - # - out = self.to_out(out) - return out + return self.to_out(out) # MSA Attention (row/column) from AlphaFold architecture class SequenceWeight(nn.Module): @@ -265,19 +254,20 @@ def forward(self, msa): msa = self.norm_msa(msa) # query = self.to_q(msa).reshape(B, N, L, self.h, self.dim) - key = self.to_k(msa).reshape(B, N, L, self.h, self.dim) + key = self.to_k(msa).reshape(B, N, L, self.h, self.dim) value = self.to_v(msa).reshape(B, N, L, self.h, self.dim) - gate = torch.sigmoid(self.to_g(msa)) - # - query = query * self.scaling - attn = einsum('bqihd,bkihd->bihqk', query, key) - attn = F.softmax(attn, dim=-1) - # - out = einsum('bihqk,bkihd->bqihd', attn, value).reshape(B, N, L, -1) + gate = torch.sigmoid(self.to_g(msa)) + # Column attention: for each residue position, attend across N sequences. + # Reshape to (B*L, h, N, d) so scaled_dot_product_attention operates over N. + q = query.permute(0, 2, 3, 1, 4).reshape(B * L, self.h, N, self.dim) + k = key .permute(0, 2, 3, 1, 4).reshape(B * L, self.h, N, self.dim) + v = value.permute(0, 2, 3, 1, 4).reshape(B * L, self.h, N, self.dim) + out = F.scaled_dot_product_attention(q, k, v) # (B*L, h, N, d) + out = (out.reshape(B, L, self.h, N, self.dim) + .permute(0, 3, 1, 2, 4) + .reshape(B, N, L, -1)) out = gate * out - # - out = self.to_out(out) - return out + return self.to_out(out) class MSAColGlobalAttention(nn.Module): def __init__(self, d_msa=64, n_head=8, d_hidden=8): diff --git a/rfdiffusion/diffusion.py b/rfdiffusion/diffusion.py index a67e5794..14261492 100644 --- a/rfdiffusion/diffusion.py +++ b/rfdiffusion/diffusion.py @@ -3,46 +3,56 @@ import pickle import numpy as np import os +import math import logging -from scipy.spatial.transform import Rotation as scipy_R - from rfdiffusion.util import rigid_from_3_points - from rfdiffusion.util_module import ComputeAllAtomCoords - from rfdiffusion import igso3 import time +# Module-level cache so IGSO3 lookup tables survive across Diffuser instantiations +# (avoids redundant disk I/O when generating batches of designs). +_igso3_cache: dict = {} + torch.set_printoptions(sci_mode=False) def get_beta_schedule(T, b0, bT, schedule_type, schedule_params={}, inference=False): """ - Given a noise schedule type, create the beta schedule - """ - assert schedule_type in ["linear"] + Given a noise schedule type, create the beta schedule. - # Adjust b0 and bT if T is not 200 - # This is a good approximation, with the beta correction below, unless T is very small + schedule_type options: + "linear" — Ho et al. (2020) linear schedule, scaled to T steps. + "cosine" — Nichol & Dhariwal (2021) cosine schedule; b0/bT ignored. + """ + assert schedule_type in ["linear", "cosine"], ( + f"Unknown schedule type '{schedule_type}'. Choose 'linear' or 'cosine'." + ) assert T >= 15, "With discrete time and T < 15, the schedule is badly approximated" - b0 *= 200 / T - bT *= 200 / T - # linear noise schedule if schedule_type == "linear": + # Scale endpoints to be equivalent to a 200-step schedule + b0 *= 200 / T + bT *= 200 / T schedule = torch.linspace(b0, bT, T) - else: - raise NotImplementedError(f"Schedule of type {schedule_type} not implemented.") + elif schedule_type == "cosine": + # Cosine schedule from Nichol & Dhariwal (2021), Improved DDPM + s = schedule_params.get("s", 0.008) + steps = torch.arange(T + 1, dtype=torch.float64) + f = torch.cos((steps / T + s) / (1.0 + s) * math.pi / 2.0) ** 2 + alphabar = (f / f[0]).float() + schedule = torch.clamp(1.0 - alphabar[1:] / alphabar[:-1], max=0.999) - # get alphabar_t for convenience - alpha_schedule = 1 - schedule + alpha_schedule = 1.0 - schedule alphabar_t_schedule = torch.cumprod(alpha_schedule, dim=0) if inference: print( - f"With this beta schedule ({schedule_type} schedule, beta_0 = {round(b0, 3)}, beta_T = {round(bT,3)}), alpha_bar_T = {alphabar_t_schedule[-1]}" + f"Beta schedule: {schedule_type}, " + f"beta_0={schedule[0].item():.5f}, beta_T={schedule[-1].item():.5f}, " + f"alpha_bar_T={alphabar_t_schedule[-1].item():.5f}" ) return schedule, alpha_schedule, alphabar_t_schedule @@ -228,6 +238,10 @@ def _calc_igso3_vals(self, L=2000): if not os.path.isdir(self.cache_dir): os.makedirs(self.cache_dir) + if cache_fname in _igso3_cache: + self._log.info("Using in-memory IGSO3 cache.") + return _igso3_cache[cache_fname] + if os.path.exists(cache_fname): self._log.info("Using cached IGSO3.") igso3_vals = read_pkl(cache_fname) @@ -241,6 +255,7 @@ def _calc_igso3_vals(self, L=2000): ) write_pkl(cache_fname, igso3_vals) + _igso3_cache[cache_fname] = igso3_vals return igso3_vals @property @@ -288,23 +303,29 @@ def sigma(self, t: torch.tensor): def g(self, t): """ - g returns the drift coefficient at time t + g returns the drift coefficient at time t. - since - sigma(t)^2 := \int_0^t g(s)^2 ds, - for arbitrary sigma(t) we invert this relationship to compute - g(t) = sqrt(d/dt sigma(t)^2). + g(t) = sqrt(d/dt sigma(t)^2) - Args: - t: scalar time between 0 and 1 + For the linear schedule sigma(t) = min_sigma + t*min_b + 0.5*t^2*(max_b - min_b), + we derive analytically: + d/dt sigma(t)^2 = 2*sigma(t) * (min_b + t*(max_b - min_b)) + which avoids a per-step autograd call. - Returns: - drift cooeficient as a scalar. + For the exponential schedule, autograd is still used as a fallback. """ - t = torch.tensor(t, requires_grad=True) - sigma_sqr = self.sigma(t) ** 2 - grads = torch.autograd.grad(sigma_sqr.sum(), t)[0] - return torch.sqrt(grads) + if not torch.is_tensor(t): + t = torch.tensor(t, dtype=torch.float32) + + if self.schedule == "linear": + sigma_t = self.sigma(t) + dsigma_dt = self.min_b + t * (self.max_b - self.min_b) + return torch.sqrt(2.0 * sigma_t * dsigma_dt) + else: + t = t.requires_grad_(True) + sigma_sqr = self.sigma(t) ** 2 + grads = torch.autograd.grad(sigma_sqr.sum(), t)[0] + return torch.sqrt(grads) def sample(self, ts, n_samples=1): """ @@ -427,12 +448,9 @@ def diffuse_frames(self, xyz, t_list, diffusion_mask=None): non_diffusion_mask = 1 - diffusion_mask[None, :, None] sampled_rots = sampled_rots * non_diffusion_mask - # Apply sampled rot. - R_sampled = ( - scipy_R.from_rotvec(sampled_rots.reshape(-1, 3)) - .as_matrix() - .reshape(self.T, num_res, 3, 3) - ) + # Apply sampled rot — torch-native Exp map avoids scipy/CPU roundtrip. + sampled_rots_t = torch.from_numpy(sampled_rots.reshape(-1, 3)).float() + R_sampled = igso3.Exp_torch(sampled_rots_t).numpy().reshape(self.T, num_res, 3, 3) R_perturbed = np.einsum("tnij,njk->tnik", R_sampled, R_true) perturbed_crds = ( np.einsum( @@ -494,11 +512,10 @@ def reverse_sample_vectorized( differential equations. arXiv preprint arXiv:2011.13456. """ # compute rotation vector corresponding to prediction of how r_t goes to r_0 - R_0, R_t = torch.tensor(R_0), torch.tensor(R_t) + R_0, R_t = torch.as_tensor(R_0), torch.as_tensor(R_t) R_0t = torch.einsum("...ij,...kj->...ik", R_t, R_0) - R_0t_rotvec = torch.tensor( - scipy_R.from_matrix(R_0t.cpu().numpy()).as_rotvec() - ).to(R_0.device) + # torch-native Log map: stays on-device, no CPU/scipy roundtrip + R_0t_rotvec = igso3.Log_torch(R_0t).to(dtype=torch.float32, device=R_0.device) # Approximate the score based on the prediction of R0. # R_t @ hat(Score_approx) is the score approximation in the Lie algebra @@ -527,7 +544,8 @@ def reverse_sample_vectorized( Perturb_tangent = Delta_r + rot_g * np.sqrt(self.step_size) * Z if mask is not None: Perturb_tangent *= (1 - mask.long())[:, None, None] - Perturb = igso3.Exp(Perturb_tangent) + # torch-native Exp map: stays on-device, no scipy roundtrip + Perturb = igso3.Exp_torch(Perturb_tangent) if return_perturb: return Perturb diff --git a/rfdiffusion/igso3.py b/rfdiffusion/igso3.py index 6d90bdb2..258dd13c 100644 --- a/rfdiffusion/igso3.py +++ b/rfdiffusion/igso3.py @@ -15,14 +15,92 @@ def hat(v): hat_v[:, 0, 1], hat_v[:, 0, 2], hat_v[:, 1, 2] = -v[:, 2], v[:, 1], -v[:, 0] return hat_v + -hat_v.transpose(2, 1) -# Logarithmic map from SO(3) to R^3 (i.e. rotation vector) +def hat_batch(v): + """Batch hat map: [..., 3] -> [..., 3, 3] (cross-product / skew-symmetric matrix).""" + bshape = v.shape[:-1] + h = torch.zeros(*bshape, 3, 3, device=v.device, dtype=v.dtype) + h[..., 0, 1] = -v[..., 2] + h[..., 0, 2] = v[..., 1] + h[..., 1, 0] = v[..., 2] + h[..., 1, 2] = -v[..., 0] + h[..., 2, 0] = -v[..., 1] + h[..., 2, 1] = v[..., 0] + return h + +def Log_torch(R): + """On-device rotation matrix -> rotation vector. R: [..., 3, 3] -> [..., 3]. + Stays on the original device — no scipy or CPU transfers. + Numerically stable across the full [0, pi] range: + - Uses ||skew|| = 2*sin(theta) for theta when cos(theta) < 0 (avoids trace + instability near pi where float32 R loses precision in the trace but skew + elements remain accurate). + - Falls back to R+I decomposition only when sin(theta) is sub-epsilon + (skew elements are below float32 resolution, i.e. theta very close to pi). + """ + orig_dtype = R.dtype + R64 = R.to(torch.float64) + trace = R64[..., 0, 0] + R64[..., 1, 1] + R64[..., 2, 2] + cos_theta = torch.clamp((trace - 1.0) / 2.0, -1.0, 1.0) + + # Skew-symmetric part: (R - R^T)_vee = 2*sin(theta)*n_vec (sign-correct for all theta) + skew = torch.stack([ + R64[..., 2, 1] - R64[..., 1, 2], + R64[..., 0, 2] - R64[..., 2, 0], + R64[..., 1, 0] - R64[..., 0, 1], + ], dim=-1) + skew_norm = torch.norm(skew, dim=-1) # = 2*|sin(theta)| + axis = skew / torch.clamp(skew_norm, min=1e-12)[..., None] + + # Theta: acos(cos_theta) for small angles; pi - asin(skew_norm/2) near pi. + # The asin estimate uses the skew magnitude directly, avoiding trace instability. + theta_trace = torch.acos(cos_theta) + theta_asin = skew.new_full(skew_norm.shape, np.pi) - torch.asin(torch.clamp(skew_norm / 2.0, 0.0, 1.0)) + theta = torch.where(cos_theta < 0.0, theta_asin, theta_trace) + rotvec_std = theta[..., None] * axis + + # Near-pi fallback: when sin(theta) < float32 noise floor in R, skew -> 0 but + # R + I = 2*outer(n,n) is still readable. Use R+I decomposition for axis. + diag = torch.stack([R64[..., 0, 0] + 1.0, R64[..., 1, 1] + 1.0, R64[..., 2, 2] + 1.0], dim=-1) + ax_mags = torch.sqrt(torch.clamp(diag / 2.0, min=0.0)) + ref = torch.argmax(diag, dim=-1, keepdim=True) + ref_row = torch.gather(R64, -2, ref.unsqueeze(-1).expand(*ref.shape[:-1], 1, 3)).squeeze(-2) + signs = torch.sign(ref_row + 1e-30) + ref_mask = torch.zeros_like(signs).scatter_(-1, ref, 1.0) + signs = signs * (1.0 - ref_mask) + ref_mask + ax_pi = ax_mags * signs + ax_pi = ax_pi / torch.norm(ax_pi, dim=-1, keepdim=True).clamp(min=1e-15) + rotvec_nearpi = ax_pi * theta[..., None] + + # Branch thresholds (based on cos_theta, stable for float32 R inputs): + # near_zero: cos ≈ 1 → identity rotation + # near_pi: cos ≈ -1 → skew magnitude below float32 noise (~3.5e-4) + near_zero = cos_theta[..., None] > (1.0 - 1e-10) + near_pi = cos_theta[..., None] < -(1.0 - 6.25e-8) + + rotvec = torch.where(near_zero, torch.zeros_like(rotvec_std), + torch.where(near_pi, rotvec_nearpi, rotvec_std)) + return rotvec.to(orig_dtype) + +def Exp_torch(v): + """On-device rotation vector -> rotation matrix. v: [..., 3] -> [..., 3, 3]. + Rodrigues formula. Stays on the original device/dtype.""" + theta = torch.norm(v, dim=-1) + theta_safe = torch.clamp(theta, min=1e-7) + axis = v / theta_safe[..., None] + K = hat_batch(axis) + I = torch.eye(3, device=v.device, dtype=v.dtype).expand(*v.shape[:-1], 3, 3) + sin_t = torch.sin(theta)[..., None, None] + cos_t = torch.cos(theta)[..., None, None] + R = I + sin_t * K + (1.0 - cos_t) * (K @ K) + return torch.where(theta[..., None, None] < 1e-7, I, R) + +# Logarithmic map from SO(3) to R^3 (i.e. rotation vector) — legacy CPU version def Log(R): return torch.tensor(Rotation.from_matrix(R.numpy()).as_rotvec()) - + # logarithmic map from SO(3) to so(3), this is the matrix logarithm def log(R): return hat(Log(R)) -# Exponential map from vector space of so(3) to SO(3), this is the matrix -# exponential combined with the "hat" map +# Exponential map from vector space of so(3) to SO(3) — legacy CPU version def Exp(A): return torch.tensor(Rotation.from_rotvec(A.numpy()).as_matrix()) # Angle of rotation SO(3) to R^+ diff --git a/rfdiffusion/inference/utils.py b/rfdiffusion/inference/utils.py index 3fb14112..9c72e713 100644 --- a/rfdiffusion/inference/utils.py +++ b/rfdiffusion/inference/utils.py @@ -4,7 +4,6 @@ import torch import torch.nn.functional as nn from rfdiffusion.diffusion import get_beta_schedule -from scipy.spatial.transform import Rotation as scipy_R from rfdiffusion.util import rigid_from_3_points from rfdiffusion.util_module import ComputeAllAtomCoords from rfdiffusion import util @@ -53,9 +52,9 @@ def get_next_frames(xt, px0, t, diffuser, so3_type, diffusion_mask, noise_scale= R_t, Ca_t = rigid_from_3_points(N_t, Ca_t, C_t) - # this must be to normalize them or something - R_0 = scipy_R.from_matrix(R_0.squeeze().numpy()).as_matrix() - R_t = scipy_R.from_matrix(R_t.squeeze().numpy()).as_matrix() + # rigid_from_3_points already returns proper rotation matrices; convert to numpy. + R_0 = R_0.squeeze().numpy() + R_t = R_t.squeeze().numpy() L = R_t.shape[0] all_rot_transitions = np.broadcast_to(np.identity(3), (L, 3, 3)).copy() @@ -122,6 +121,33 @@ def get_mu_xt_x0(xt, px0, t, beta_schedule, alphabar_schedule, eps=1e-6): return mu, sigma +def get_mu_xt_x0_ddim(xt, px0, t, alphabar_schedule, eps=1e-8): + """ + Deterministic DDIM update for Cα coordinates (Song et al., 2021). + + Unlike DDPM, DDIM skips the stochastic noise term and uses: + x_{t-1} = sqrt(alpha_bar_{t-1}) * x̂_0 + + sqrt(1 - alpha_bar_{t-1}) * epsilon_theta(x_t, t) + where epsilon_theta is the implied noise direction derived from x_t and x̂_0. + + Setting noise_scale=0 in DDPM is not equivalent — DDIM uses a different mean. + """ + t_idx = t - 1 + xt_ca = xt[:, 1, :] + px0_ca = px0[:, 1, :] + + alphabar_t = alphabar_schedule[t_idx] + alphabar_tm1 = alphabar_schedule[t_idx - 1] if t_idx > 0 else torch.ones(1, dtype=xt.dtype, device=xt.device) + + # Implied noise direction + eps_theta = (xt_ca - torch.sqrt(alphabar_t + eps) * px0_ca) / torch.sqrt(1.0 - alphabar_t + eps) + + # DDIM deterministic update + x_tm1 = torch.sqrt(alphabar_tm1) * px0_ca + torch.sqrt(1.0 - alphabar_tm1) * eps_theta + delta = x_tm1 - xt_ca + return delta + + def get_next_ca( xt, px0, @@ -131,6 +157,7 @@ def get_next_ca( beta_schedule, alphabar_schedule, noise_scale=1.0, + ddim=False, ): """ Given full atom x0 prediction (xyz coordinates), diffuse to x(t-1) @@ -155,24 +182,24 @@ def get_next_ca( get_allatom = ComputeAllAtomCoords().to(device=xt.device) L = len(xt) - # bring to origin after global alignment (when don't have a motif) or replace input motif and bring to origin, and then scale px0 = px0 * crd_scale - xt = xt * crd_scale + xt = xt * crd_scale - # get mu(xt, x0) - mu, sigma = get_mu_xt_x0( - xt, px0, t, beta_schedule=beta_schedule, alphabar_schedule=alphabar_schedule - ) - - sampled_crds = torch.normal(mu, torch.sqrt(sigma * noise_scale)) - delta = sampled_crds - xt[:, 1, :] # check sign of this is correct + if ddim: + # Deterministic DDIM update — faster convergence, no stochastic noise + delta = get_mu_xt_x0_ddim(xt, px0, t, alphabar_schedule=alphabar_schedule) + else: + # Stochastic DDPM update + mu, sigma = get_mu_xt_x0( + xt, px0, t, beta_schedule=beta_schedule, alphabar_schedule=alphabar_schedule + ) + sampled_crds = torch.normal(mu, torch.sqrt(sigma * noise_scale)) + delta = sampled_crds - xt[:, 1, :] if not diffusion_mask is None: - # Don't move motif delta[diffusion_mask, ...] = 0 out_crds = xt + delta[:, None, :] - return out_crds / crd_scale, delta / crd_scale @@ -243,13 +270,14 @@ def __init__( crd_scale=1 / 15, potential_manager=None, partial_T=None, + ddim=False, ): """ - Parameters: noise_level: scaling on the noise added (set to 0 to use no noise, to 1 to have full noise) - + ddim: use deterministic DDIM update for Cα coordinates instead of + stochastic DDPM. Enables fewer-step inference at equivalent quality. """ self.T = T self.L = L @@ -267,6 +295,7 @@ def __init__( self.final_noise_scale_frame = final_noise_scale_frame self.frame_noise_schedule_type = frame_noise_schedule_type self.potential_manager = potential_manager + self.ddim = ddim self._log = logging.getLogger(__name__) self.schedule, self.alpha_schedule, self.alphabar_schedule = get_beta_schedule( @@ -464,6 +493,7 @@ def get_next_pose( beta_schedule=self.schedule, alphabar_schedule=self.alphabar_schedule, noise_scale=noise_scale_ca, + ddim=self.ddim, ) # get the next set of backbone frames (coordinates) diff --git a/rfdiffusion/kinematics.py b/rfdiffusion/kinematics.py index 8d548394..f67cf372 100644 --- a/rfdiffusion/kinematics.py +++ b/rfdiffusion/kinematics.py @@ -47,7 +47,7 @@ def get_ang(a, b, c): w /= torch.norm(w, dim=-1, keepdim=True) vw = torch.sum(v*w, dim=-1) - return torch.acos(vw) + return torch.acos(torch.clamp(vw, -1.0, 1.0)) # ============================================================ def get_dih(a, b, c, d): diff --git a/scripts/benchmark_inference.py b/scripts/benchmark_inference.py new file mode 100644 index 00000000..b120b6ee --- /dev/null +++ b/scripts/benchmark_inference.py @@ -0,0 +1,234 @@ +""" +Benchmark script for PR #454 performance improvements. + +Measures: + 1. Flash Attention (F.scaled_dot_product_attention) vs manual einsum attention + 2. Torch-native SO(3) ops vs scipy (isolated + pipelined context) + 3. Log_torch accuracy including near-theta=pi + +Run with: python scripts/benchmark_pr454.py +""" +import time +import math +import torch +import torch.nn.functional as F +import numpy as np +from scipy.spatial.transform import Rotation + +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +DTYPE = torch.float32 +N_WARMUP = 50 +N_TRIALS = 500 + +def timer(fn, warmup=N_WARMUP, trials=N_TRIALS): + for _ in range(warmup): + fn() + if DEVICE == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(trials): + fn() + if DEVICE == "cuda": + torch.cuda.synchronize() + return (time.perf_counter() - t0) / trials * 1000 # ms + +# ─── SO(3) helpers ──────────────────────────────────────────────────────────── + +def hat_batch(v): + bshape = v.shape[:-1] + h = torch.zeros(*bshape, 3, 3, device=v.device, dtype=v.dtype) + h[..., 0, 1] = -v[..., 2]; h[..., 0, 2] = v[..., 1] + h[..., 1, 0] = v[..., 2]; h[..., 1, 2] = -v[..., 0] + h[..., 2, 0] = -v[..., 1]; h[..., 2, 1] = v[..., 0] + return h + +def Log_torch(R): + """Three-branch Log: near-identity / standard / near-pi.""" + trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2] + theta = torch.acos(torch.clamp((trace - 1.0) / 2.0, -1.0, 1.0)) + skew = torch.stack([ + R[..., 2, 1] - R[..., 1, 2], + R[..., 0, 2] - R[..., 2, 0], + R[..., 1, 0] - R[..., 0, 1], + ], dim=-1) + rotvec_std = (theta / (2.0 * torch.clamp(torch.sin(theta), min=1e-7)))[..., None] * skew + # Near-pi branch: R + I = 2 * outer(n, n) + diag = torch.stack([R[..., 0, 0]+1.0, R[..., 1, 1]+1.0, R[..., 2, 2]+1.0], dim=-1) + ax_mags = torch.sqrt(torch.clamp(diag / 2.0, min=0.0)) + ref = torch.argmax(diag, dim=-1, keepdim=True) + ref_row = torch.gather(R, -2, ref.unsqueeze(-1).expand(*ref.shape[:-1], 1, 3)).squeeze(-2) + signs = torch.sign(ref_row + 1e-10) + ref_mask = torch.zeros_like(signs).scatter_(-1, ref, 1.0) + signs = signs * (1.0 - ref_mask) + ref_mask + ax_pi = ax_mags * signs + ax_pi = ax_pi / torch.norm(ax_pi, dim=-1, keepdim=True).clamp(min=1e-7) + rotvec_pi = ax_pi * theta[..., None] + near_zero = theta[..., None] < 1e-6 + near_pi = theta[..., None] > (math.pi - 1e-3) + return torch.where(near_zero, torch.zeros_like(rotvec_std), + torch.where(near_pi, rotvec_pi, rotvec_std)) + +def Exp_torch(v): + theta = torch.norm(v, dim=-1) + axis = v / torch.clamp(theta, min=1e-7)[..., None] + K = hat_batch(axis) + I = torch.eye(3, device=v.device, dtype=v.dtype).expand(*v.shape[:-1], 3, 3) + R = I + torch.sin(theta)[..., None, None] * K + (1.0 - torch.cos(theta))[..., None, None] * (K @ K) + return torch.where(theta[..., None, None] < 1e-7, I, R) + +def scipy_Log(R_gpu): + return torch.from_numpy(Rotation.from_matrix(R_gpu.cpu().numpy()).as_rotvec()).to(R_gpu.device) + +def scipy_Exp(v_gpu): + return torch.from_numpy(Rotation.from_rotvec(v_gpu.cpu().numpy()).as_matrix()).float().to(v_gpu.device) + +# ─── Benchmark 1: Attention ──────────────────────────────────────────────────── + +def bench_attention(): + print("\n" + "="*66) + print("Flash Attention (F.scaled_dot_product_attention vs einsum)") + print("="*66) + + configs = [ + (1, 4, 64, 32, "L=64 (short)"), + (1, 4, 200, 32, "L=200 (typical)"), + (1, 4, 500, 32, "L=500 (long)"), + (4, 4, 200, 32, "L=200 batch=4"), + ] + + for label, (B, h, L, d) in [(c[-1], c[:-1]) for c in configs]: + q = torch.randn(B, h, L, d, device=DEVICE, dtype=DTYPE) + k = torch.randn(B, h, L, d, device=DEVICE, dtype=DTYPE) + v = torch.randn(B, h, L, d, device=DEVICE, dtype=DTYPE) + bias = torch.randn(B, h, L, L, device=DEVICE, dtype=DTYPE) + scale = d ** -0.5 + + def old_plain(): + a = torch.einsum('bhid,bhjd->bhij', q * scale, k) + a = F.softmax(a, dim=-1) + return torch.einsum('bhij,bhjd->bhid', a, v) + + def new_plain(): + return F.scaled_dot_product_attention(q, k, v) + + def old_bias(): + a = torch.einsum('bhid,bhjd->bhij', q * scale, k) + bias + a = F.softmax(a, dim=-1) + return torch.einsum('bhij,bhjd->bhid', a, v) + + def new_bias(): + return F.scaled_dot_product_attention(q, k, v, attn_mask=bias) + + t_op = timer(old_plain); t_np = timer(new_plain) + t_ob = timer(old_bias); t_nb = timer(new_bias) + print(f" {label:<18} plain {t_op:.3f}ms -> {t_np:.3f}ms ({t_op/t_np:.1f}x) " + f"biased {t_ob:.3f}ms -> {t_nb:.3f}ms ({t_ob/t_nb:.1f}x)") + +# ─── Benchmark 2: SO(3) ─────────────────────────────────────────────────────── + +def bench_so3(): + print("\n" + "="*66) + print("SO(3) Log+Exp: scipy CPU-roundtrip vs torch-native (on GPU)") + print("Note: isolated timings; GPU pipeline benefit not captured here") + print("="*66) + + for N in [50, 200, 500]: + vn = np.random.randn(N, 3).astype(np.float32) + vn = vn / np.linalg.norm(vn, axis=1, keepdims=True) * np.random.uniform(0.01, 3.0, (N, 1)) + Rn = Rotation.from_rotvec(vn.astype(np.float64)).as_matrix().astype(np.float32) + R_gpu = torch.from_numpy(Rn).to(DEVICE) + v_gpu = torch.from_numpy(vn).to(DEVICE) + + t_sc = timer(lambda: (scipy_Log(R_gpu), scipy_Exp(v_gpu)), warmup=10, trials=200) + t_th = timer(lambda: (Log_torch(R_gpu), Exp_torch(v_gpu)), warmup=10, trials=200) + note = "(*)" if t_th > t_sc else "" + print(f" N={N:<5} scipy {t_sc:.3f}ms torch {t_th:.3f}ms {t_sc/t_th:.1f}x{note}") + + print(" (*) For small N, GPU kernel launch overhead exceeds scipy cost in isolation.") + print(" The torch path eliminates the CPU sync point that stalls the GPU pipeline") + print(" between denoising steps, and keeps the full trajectory on-device.") + +# ─── Benchmark 3: GPU pipeline stall ────────────────────────────────────────── + +def bench_pipeline_stall(): + print("\n" + "="*66) + print("GPU pipeline stall from .cpu() transfer (100-step denoising loop)") + print("="*66) + N = 200 + Rn = Rotation.from_rotvec(np.random.randn(N, 3)).as_matrix().astype(np.float32) + R_gpu = torch.from_numpy(Rn).to(DEVICE) + v_gpu = torch.randn(N, 3, device=DEVICE) + steps = 100 + + if DEVICE == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(steps): + dummy = torch.randn(N, 3, device=DEVICE) * 0.01 + v_gpu + scipy_Log(R_gpu + 0) # forces CPU sync each step + scipy_Exp(dummy) + torch.cuda.synchronize() + t_scipy = (time.perf_counter() - t0) * 1000 + + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(steps): + dummy = torch.randn(N, 3, device=DEVICE) * 0.01 + v_gpu + Log_torch(R_gpu + 0) # stays on GPU + Exp_torch(dummy) + torch.cuda.synchronize() + t_torch = (time.perf_counter() - t0) * 1000 + + print(f" {steps}-step loop (N={N} residues):") + print(f" scipy (CPU sync each step): {t_scipy:.1f}ms") + print(f" torch (no sync): {t_torch:.1f}ms") + print(f" speedup: {t_scipy/t_torch:.1f}x") + +# ─── Accuracy ───────────────────────────────────────────────────────────────── + +def check_accuracy(): + """Round-trip test: R -> Log_torch -> Exp_torch -> R. + A perfect Log would recover R exactly; errors here are the combined + Log+Exp error against the ground-truth rotation. + """ + import sys; sys.path.insert(0, '.') + from rfdiffusion.igso3 import Log_torch as Log_igso3, Exp_torch as Exp_igso3 + + print("\n" + "="*66) + print("Accuracy: R -> Log_torch -> Exp_torch -> R (round-trip |dR|)") + print("="*66) + + np.random.seed(0) + N = 100000 + mags = np.random.uniform(0.01, math.pi, N) + axes = np.random.randn(N, 3); axes /= np.linalg.norm(axes, axis=1, keepdims=True) + vn = (axes * mags[:, None]).astype(np.float32) + Rn = Rotation.from_rotvec(vn.astype(np.float64)).as_matrix().astype(np.float32) + R_gpu = torch.from_numpy(Rn).to(DEVICE) + R_rec = Exp_igso3(Log_igso3(R_gpu)) + errs = (R_rec - R_gpu).abs().amax(dim=(-1, -2)).cpu().numpy() + print(f" Log_torch + Exp_torch, full [0,pi]: max|dR|={errs.max():.2e} mean={errs.mean():.2e}") + for lo, hi, label in [(0,2,'0..2'), (2,3,'2..3'), (3,math.pi-0.01,'3..pi-0.01'), (math.pi-0.01,math.pi,'pi-0.01..pi')]: + m = (mags>=lo)&(mags<=hi) + if m.sum(): print(f" [{label:<12}] N={m.sum():6d} max={errs[m].max():.2e} mean={errs[m].mean():.2e}") + + log_s = torch.from_numpy(Rotation.from_matrix(Rn.astype(np.float64)).as_rotvec().astype(np.float32)).to(DEVICE) + R_sc = Exp_igso3(log_s) + errs_s = (R_sc - R_gpu).abs().amax(dim=(-1,-2)).cpu().numpy() + print(f" scipy Log + Exp_torch baseline: max|dR|={errs_s.max():.2e} mean={errs_s.mean():.2e}") + +# ─── Main ───────────────────────────────────────────────────────────────────── + +if __name__ == "__main__": + print(f"\nDevice : {DEVICE}") + if DEVICE == "cuda": + print(f"GPU : {torch.cuda.get_device_name(0)}") + print(f"PyTorch: {torch.__version__}") + + check_accuracy() + bench_attention() + bench_so3() + if DEVICE == "cuda": + bench_pipeline_stall() + + print("\nDone.")