diff --git a/point_e/models/perceiver.py b/point_e/models/perceiver.py
index 9e7c730..125e919 100644
--- a/point_e/models/perceiver.py
+++ b/point_e/models/perceiver.py
@@ -1,11 +1,10 @@
 import math
-from typing import Optional
-
 import torch
 import torch.nn as nn
 
 from .checkpoint import checkpoint
 from .transformer import MLP, init_linear
+from typing import Optional
 
 
 class MultiheadCrossAttention(nn.Module):
@@ -36,10 +35,11 @@ def __init__(
         init_linear(self.c_proj, init_scale)
 
     def forward(self, x, data):
-        x = self.c_q(x)
-        data = self.c_kv(data)
-        x = checkpoint(self.attention, (x, data), (), True)
-        x = self.c_proj(x)
+        with torch.no_grad():
+            x = self.c_q(x)
+            data = self.c_kv(data)
+            x = checkpoint(self.attention, (x, data), (), True)
+            x = self.c_proj(x)
         return x
 
 
@@ -52,18 +52,19 @@ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_da
         self.n_data = n_data
 
     def forward(self, q, kv):
-        _, n_ctx, _ = q.shape
-        bs, n_data, width = kv.shape
-        attn_ch = width // self.heads // 2
-        scale = 1 / math.sqrt(math.sqrt(attn_ch))
-        q = q.view(bs, n_ctx, self.heads, -1)
-        kv = kv.view(bs, n_data, self.heads, -1)
-        k, v = torch.split(kv, attn_ch, dim=-1)
-        weight = torch.einsum(
-            "bthc,bshc->bhts", q * scale, k * scale
-        )  # More stable with f16 than dividing afterwards
-        wdtype = weight.dtype
-        weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
+        with torch.no_grad():
+            _, n_ctx, _ = q.shape
+            bs, n_data, width = kv.shape
+            attn_ch = width // self.heads // 2
+            scale = 1 / math.sqrt(math.sqrt(attn_ch))
+            q = q.view(bs, n_ctx, self.heads, -1)
+            kv = kv.view(bs, n_data, self.heads, -1)
+            k, v = torch.split(kv, attn_ch, dim=-1)
+            weight = torch.einsum(
+                "bthc,bshc->bhts", q * scale, k * scale
+            )  # More stable with f16 than dividing afterwards
+            wdtype = weight.dtype
+            weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
         return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
 
 
@@ -84,6 +85,10 @@ def __init__(
         if data_width is None:
             data_width = width
 
+        self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
+        self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
+        self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
+
         self.attn = MultiheadCrossAttention(
             device=device,
             dtype=dtype,
@@ -93,16 +98,15 @@ def __init__(
             data_width=data_width,
             init_scale=init_scale,
         )
-        self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
-        self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
         self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
-        self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
 
     def forward(self, x: torch.Tensor, data: torch.Tensor):
-        x = x + self.attn(self.ln_1(x), self.ln_2(data))
-        x = x + self.mlp(self.ln_3(x))
-        return x
+        with torch.no_grad():
+            # Normalize input tensors and pass them through the attention and MLP layers
+            x = x + self.attn(self.ln_1(x), self.ln_2(data))
+            x = x + self.mlp(self.ln_3(x))
+            return x
 
 
 class SimplePerceiver(nn.Module):
     """
@@ -141,6 +145,7 @@ def __init__(
         )
 
     def forward(self, x: torch.Tensor, data: torch.Tensor):
-        for block in self.resblocks:
-            x = block(x, data)
+        with torch.no_grad():
+            for block in self.resblocks:
+                x = block(x, data)
         return x
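
Note: a minimal, self-contained sketch (plain PyTorch, not the point_e API) of what the torch.no_grad() wrapper introduced in this patch does: tensors produced inside the context carry no autograd history, so the wrapped forward passes become inference-only and cannot back-propagate into the projection, LayerNorm, or MLP weights.

import torch
import torch.nn as nn

proj = nn.Linear(8, 8)          # stand-in for a projection such as c_q / c_kv / c_proj
x = torch.randn(2, 8)

y_train = proj(x)               # ordinary forward: tracked by autograd
with torch.no_grad():           # forward as patched above
    y_infer = proj(x)

print(y_train.requires_grad)    # True  -> gradients can reach proj.weight
print(y_infer.requires_grad)    # False -> y_infer.sum().backward() would raise,
                                #          since nothing in that graph requires grad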