@@ -39,11 +39,13 @@ def forward(self, time_steps: np.ndarray) -> np.ndarray:
3939 (1, 3, 4)
4040 """
4141
42+
4243 linear = self .w0 * time_steps + self .b0
43- periodic = np .sin (time_steps * self .w [: , None , :] + self .b [: , None , :])
44+ periodic = np .sin (time_steps * self .w [None , None , :] + self .b [None , None , :])
4445 return np .concatenate ([linear , periodic ], axis = - 1 )
4546
4647
48+
4749# -------------------------------
4850# 🔹 LayerNorm
4951# -------------------------------
@@ -267,15 +269,26 @@ def __init__(
267269 def forward (self , eeg_data : np .ndarray ) -> np .ndarray :
268270 """
269271 >>> model = EEGTransformer(4, 2, 8, 2, seed=0)
270- >>> x = np.ones((1, 3, 4 ))
272+ >>> x = np.ones((1, 3, 1 ))
271273 >>> out = model.forward(x)
272274 >>> out.shape
273275 (1, 1)
274276 """
277+ # Ensure input shape is (batch, seq_len, 1)
278+ if eeg_data .shape [- 1 ] != 1 :
279+ eeg_data = eeg_data [..., :1 ]
280+
281+ # Time2Vec positional encoding
275282 x = self .time2vec .forward (eeg_data )
283+
284+ # Transformer encoder
276285 x = self .encoder .forward (x )
286+
287+ # Attention pooling
277288 x = self .pooling .forward (x )
278- out = np .tensordot (x , self .w_out , axes = ([1 ], [0 ])) + self .b_out
289+
290+ # Final linear layer
291+ out = np .dot (x , self .w_out ) + self .b_out # shape (batch, output_dim)
279292 return out
280293
281294
0 commit comments