@@ -11,7 +11,7 @@

 from sensa.layers import mask_utils
 from sensa.layers.dyt import DyT
-from sensa.layers.encoder import Encoder
+from sensa.layers.encoder import Encoder2
 from sensa.layers.last_pool import LastPool
 from sensa.models.base import BaseModel
 from sensa.models.registry import register_model
@@ -157,10 +157,12 @@ class VIT(BaseModel):
             Defaults to "token".
         last_stride (int, optional):
             Stride for the stem's final downsampling block. Defaults to 4.
+        act_layer (Callable[..., torch.nn.Module]):
+            Activation layer for encoder. Defaults to torch.nn.GELU.
         norm_layer (Callable[..., nn.Module] | str, optional):
             Normalization layer for encoder. Defaults to LayerNorm(eps=1e-6).
-        use_sincos_pos_token (bool, optional):
-            Whether to use fixed sinusoidal positional embeddings. Defaults to False.
+        pos_token (Literal["learned", "sincos", "rope"]):
+            Positional token type. Defaults to "learned".
     """

     def __init__(
@@ -177,8 +179,9 @@ def __init__(
         first_stride: int = 2,
         last_pool: Literal["avg", "full", "half", "token", None] = "token",
         last_stride: int = 4,
+        act_layer: Callable[..., torch.nn.Module] = torch.nn.GELU,
         norm_layer: Callable[..., torch.nn.Module] | str | None = None,
-        use_sincos_pos_token: bool = False,
+        pos_token: Literal["learned", "sincos", "rope"] = "learned",
     ):
         super().__init__()
         self.image_size = torch.nn.modules.utils._pair(image_size)
@@ -216,19 +219,20 @@ def __init__(
             self.class_token = torch.nn.Parameter(torch.zeros(1, 1, hidden_dim))
             extra_tokens += 1

-        self.encoder = Encoder(
+        self.encoder = Encoder2(
             size=self.stem_size,
             extra_tokens=extra_tokens,
             num_layers=num_layers,
             num_heads=num_heads,
             hidden_dim=hidden_dim,
             mlp_dim=mlp_dim,
             dropout=0.0,
-            attention_dropout=0.0,
+            act_layer=act_layer,
             norm_layer=norm_layer,
+            pos_token=pos_token,
         )
-        if use_sincos_pos_token:
-            self.encoder.use_sincos_pos_token(extra_tokens=int(last_pool == "token"), size=self.stem_size)
+        # if pos_token == "sincos":
+        #     self.encoder.use_sincos_pos_token(extra_tokens=int(last_pool == "token"), size=self.stem_size)
         self.seq_length = self.encoder.seq_length

         if self.mask_ratio > 0:
@@ -310,7 +314,11 @@ def param_groups(self) -> list[dict[str, Any]]:
         for i in range(0, len(self.encoder.layers), 2):
             groups += self._param_groups(self.encoder.layers[slice(i, i + 2)])
         self._param_groups(self.encoder.ln, groups=groups[-2:])
-        if isinstance(self.encoder.pos_token, torch.nn.Parameter) and self.encoder.pos_token.requires_grad:
+        if (
+            hasattr(self.encoder, "pos_token")
+            and isinstance(self.encoder.pos_token, torch.nn.Parameter)
+            and self.encoder.pos_token.requires_grad
+        ):
             groups[-1]["params"].append(self.encoder.pos_token)
         if (
             hasattr(self, "class_token")
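
For reference, a minimal usage sketch of the constructor arguments this change introduces. The import path and any constructor arguments not visible in this diff (e.g. patch size or mask ratio) are assumptions, not confirmed by the commit; only image_size, num_layers, num_heads, hidden_dim, mlp_dim, last_pool, act_layer, and pos_token appear in the hunks above.

import torch

# Hypothetical import path; only the class name VIT appears in this diff.
from sensa.models.vit import VIT

# Illustrative values; the new arguments are act_layer and pos_token.
model = VIT(
    image_size=224,
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    mlp_dim=3072,
    last_pool="token",
    act_layer=torch.nn.GELU,  # new argument: activation layer passed to Encoder2
    pos_token="sincos",       # new argument: replaces use_sincos_pos_token=True
)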