
Commit fbb5ab9

Encoder2 is now default
1 parent: 23f698e

Showing 6 changed files with 35 additions and 18 deletions.


examples/finetune_minivit_imagenet.yaml

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ backbone:
   first_stride: 2
   last_pool: half
   last_stride: 4
-  use_sincos_pos_token: True
+  pos_token: "sincos"
 
 mode: full
 skip_freeze_prefixes:
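
Older configs that still set the removed boolean key need a one-line rename. A minimal migration sketch (a hypothetical helper, not part of this commit; it assumes configs are parsed into plain dicts):

    def migrate_pos_token(cfg: dict) -> dict:
        # Map the removed boolean onto the new literal-valued key:
        # True meant fixed sinusoidal embeddings, False meant learned ones.
        if "use_sincos_pos_token" in cfg:
            legacy = cfg.pop("use_sincos_pos_token")
            cfg.setdefault("pos_token", "sincos" if legacy else "learned")
        return cfg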

examples/mae_minivit_imagenet.yaml

Lines changed: 2 additions & 1 deletion
@@ -23,7 +23,7 @@ mae_encoder:
   first_stride: 2
   last_pool: half
   last_stride: 4
-  use_sincos_pos_token: False
+  pos_token: "sincos"
 
 mae_decoder:
   name: MAEDecoder
@@ -37,6 +37,7 @@ mae_decoder:
   mlp_dim: 2048
   num_layers: 4
   num_heads: 16
+  pos_token: "sincos"
 
 loss:
   name: MSELoss

src/sensa/models/mae_decoder.py

Lines changed: 12 additions & 4 deletions
@@ -1,9 +1,10 @@
 from collections.abc import Callable
 from functools import partial
+from typing import Literal
 
 import torch
 
-from sensa.layers.encoder import Encoder
+from sensa.layers.encoder import Encoder2
 from sensa.models.base import BaseModel
 from sensa.models.registry import register_model
 
@@ -29,8 +30,12 @@ class MAEDecoder(BaseModel):
             Number of transformer layers in the decoder.
         num_heads (int):
             Number of attention heads in each decoder layer.
+        act_layer (Callable[..., torch.nn.Module]):
+            Activation layer for the decoder. Defaults to torch.nn.GELU.
         norm_layer (Callable[..., torch.nn.Module], optional):
             Constructor for the normalization layer. Defaults to `partial(torch.nn.LayerNorm, eps=1e-6)`.
+        pos_token (Literal["learned", "sincos", "rope"]):
+            Positional token type. Defaults to "sincos".
     """
 
     def __init__(
@@ -43,7 +48,9 @@ def __init__(
         mlp_dim: int,
         num_layers: int,
         num_heads: int,
+        act_layer: Callable[..., torch.nn.Module] = torch.nn.GELU,
         norm_layer: Callable[..., torch.nn.Module] | None = None,
+        pos_token: Literal["learned", "sincos", "rope"] = "sincos",
     ):
         super().__init__()
         self.image_size = image_size
@@ -55,18 +62,19 @@ def __init__(
             torch.nn.Identity() if encoder_dim == decoder_dim else torch.nn.Linear(encoder_dim, decoder_dim)
         )
         # build the decoder transformer
-        self.decoder = Encoder(
+        self.decoder = Encoder2(
             size=self.stem_size,
             extra_tokens=0,
             num_layers=num_layers,
             num_heads=num_heads,
             hidden_dim=decoder_dim,
             mlp_dim=mlp_dim,
             dropout=0.0,
-            attention_dropout=0.0,
+            act_layer=act_layer,
             norm_layer=partial(torch.nn.LayerNorm, eps=1e-6) if norm_layer is None else norm_layer,
+            pos_token=pos_token,
         )
-        self.decoder.use_sincos_pos_token(extra_tokens=0, size=self.stem_size)
+        # self.decoder.use_sincos_pos_token(extra_tokens=0, size=self.stem_size)
         # projection head to map decoder outputs back to patch pixels
         self.predict = torch.nn.Linear(decoder_dim, patch_size * patch_size * channels)
 
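
For reference, the "sincos" option names the fixed 2D sine-cosine embedding familiar from MAE-style models. Encoder2's implementation is not shown in this diff, so the following is a sketch of one common convention, not the library's code:

    import torch

    def sincos_pos_embed_2d(h: int, w: int, dim: int, temperature: float = 10000.0) -> torch.Tensor:
        # Fixed (non-learned) positional embedding of shape (1, h * w, dim);
        # a quarter of the channels each for sin(x), cos(x), sin(y), cos(y).
        assert dim % 4 == 0, "dim must be divisible by 4"
        grid_y, grid_x = torch.meshgrid(
            torch.arange(h, dtype=torch.float32),
            torch.arange(w, dtype=torch.float32),
            indexing="ij",
        )
        omega = 1.0 / temperature ** (torch.arange(dim // 4, dtype=torch.float32) / (dim // 4))
        x = grid_x.reshape(-1, 1) * omega
        y = grid_y.reshape(-1, 1) * omega
        return torch.cat([x.sin(), x.cos(), y.sin(), y.cos()], dim=1).unsqueeze(0)

Because such a table is fully determined by the grid size, it can be built inside the encoder's constructor, which is presumably why the post-hoc use_sincos_pos_token() call is now commented out in favor of the pos_token argument.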

src/sensa/models/vit.py

Lines changed: 17 additions & 9 deletions
@@ -11,7 +11,7 @@
 
 from sensa.layers import mask_utils
 from sensa.layers.dyt import DyT
-from sensa.layers.encoder import Encoder
+from sensa.layers.encoder import Encoder2
 from sensa.layers.last_pool import LastPool
 from sensa.models.base import BaseModel
 from sensa.models.registry import register_model
@@ -157,10 +157,12 @@ class VIT(BaseModel):
             Defaults to "token".
         last_stride (int, optional):
             Stride for the stem's final downsampling block. Defaults to 4.
+        act_layer (Callable[..., torch.nn.Module]):
+            Activation layer for encoder. Defaults to torch.nn.GELU.
         norm_layer (Callable[..., nn.Module] | str, optional):
             Normalization layer for encoder. Defaults to LayerNorm(eps=1e-6).
-        use_sincos_pos_token (bool, optional):
-            Whether to use fixed sinusoidal positional embeddings. Defaults to False.
+        pos_token (Literal["learned", "sincos", "rope"]):
+            Positional token type. Defaults to "learned".
     """
 
     def __init__(
@@ -177,8 +179,9 @@ def __init__(
         first_stride: int = 2,
         last_pool: Literal["avg", "full", "half", "token", None] = "token",
         last_stride: int = 4,
+        act_layer: Callable[..., torch.nn.Module] = torch.nn.GELU,
         norm_layer: Callable[..., torch.nn.Module] | str | None = None,
-        use_sincos_pos_token: bool = False,
+        pos_token: Literal["learned", "sincos", "rope"] = "learned",
     ):
         super().__init__()
         self.image_size = torch.nn.modules.utils._pair(image_size)
@@ -216,19 +219,20 @@
             self.class_token = torch.nn.Parameter(torch.zeros(1, 1, hidden_dim))
             extra_tokens += 1
 
-        self.encoder = Encoder(
+        self.encoder = Encoder2(
             size=self.stem_size,
             extra_tokens=extra_tokens,
             num_layers=num_layers,
             num_heads=num_heads,
             hidden_dim=hidden_dim,
             mlp_dim=mlp_dim,
             dropout=0.0,
-            attention_dropout=0.0,
+            act_layer=act_layer,
             norm_layer=norm_layer,
+            pos_token=pos_token,
         )
-        if use_sincos_pos_token:
-            self.encoder.use_sincos_pos_token(extra_tokens=int(last_pool == "token"), size=self.stem_size)
+        # if pos_token == "sincos":
+        #     self.encoder.use_sincos_pos_token(extra_tokens=int(last_pool == "token"), size=self.stem_size)
         self.seq_length = self.encoder.seq_length
 
         if self.mask_ratio > 0:
@@ -310,7 +314,11 @@ def param_groups(self) -> list[dict[str, Any]]:
         for i in range(0, len(self.encoder.layers), 2):
             groups += self._param_groups(self.encoder.layers[slice(i, i + 2)])
         self._param_groups(self.encoder.ln, groups=groups[-2:])
-        if isinstance(self.encoder.pos_token, torch.nn.Parameter) and self.encoder.pos_token.requires_grad:
+        if (
+            hasattr(self.encoder, "pos_token")
+            and isinstance(self.encoder.pos_token, torch.nn.Parameter)
+            and self.encoder.pos_token.requires_grad
+        ):
             groups[-1]["params"].append(self.encoder.pos_token)
         if (
             hasattr(self, "class_token")
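
The widened condition in param_groups() accounts for the fact that, under the new options, the encoder may have no trainable positional table at all: presumably a fixed buffer for "sincos" and no table for "rope" (an assumption; Encoder2's internals are not in this diff). A minimal sketch of the equivalent getattr form:

    # Collect pos_token into the last parameter group only when it is a
    # trainable Parameter; buffers and missing attributes are skipped.
    pos = getattr(self.encoder, "pos_token", None)
    if isinstance(pos, torch.nn.Parameter) and pos.requires_grad:
        groups[-1]["params"].append(pos)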

tests/samples/mae_vit.yaml

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ mae_encoder:
   last_pool: half
   last_stride: 4
   norm_layer: dyt
-  use_sincos_pos_token: True
+  pos_token: "sincos"
 
 mae_decoder:
   name: MAEDecoder

tests/test_models.py

Lines changed: 2 additions & 2 deletions
@@ -47,7 +47,7 @@ def test_vit_features():
        num_classes=None,
        in_channels=3,
        last_pool=None,
-        use_sincos_pos_token=True,
+        pos_token="sincos",
     )
     output = model(torch.randn(1, 3, 128, 128))
     assert output.shape[-1] == model.hidden_dim, f"output shape must be {output.shape}"
@@ -67,7 +67,7 @@ def test_vit_features_with_pool():
        num_classes=None,
        in_channels=3,
        last_pool="half",
-        use_sincos_pos_token=True,
+        pos_token="rope",
     )
     output = model(torch.randn(1, 3, 128, 128))
     size = model.hidden_dim * (model.stem_size[0] // 2) * (model.stem_size[1] // 2)
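
The second test now exercises the new "rope" option. Rotary embeddings encode position by rotating channel pairs of the query and key vectors rather than adding a table to the token embeddings; the sketch below is a generic 1D formulation for illustration (how Encoder2 applies it to 2D grids, e.g. by splitting head channels per axis, is not shown in this diff):

    import torch

    def apply_rope(q: torch.Tensor, pos: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
        # q: (..., seq, head_dim) with even head_dim; pos: (seq,) token positions.
        head_dim = q.shape[-1]
        freqs = 1.0 / theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
        angles = pos.float()[:, None] * freqs            # (seq, head_dim // 2)
        sin, cos = angles.sin(), angles.cos()
        q1, q2 = q[..., 0::2], q[..., 1::2]              # interleaved channel pairs
        out = torch.stack([q1 * cos - q2 * sin, q1 * sin + q2 * cos], dim=-1)
        return out.flatten(-2)                           # back to (..., seq, head_dim)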
