Skip to content

Commit 8b55a8f

Browse files
authored
Update real_time_encoder_transformer.py
1 parent 21c18c2 commit 8b55a8f

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

neural_network/real_time_encoder_transformer.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,13 @@ def forward(self, time_steps: np.ndarray) -> np.ndarray:
3939
(1, 3, 4)
4040
"""
4141

42+
4243
linear = self.w0 * time_steps + self.b0
43-
periodic = np.sin(time_steps * self.w[:, None, :] + self.b[:, None, :])
44+
periodic = np.sin(time_steps * self.w[None, None, :] + self.b[None, None, :])
4445
return np.concatenate([linear, periodic], axis=-1)
4546

4647

48+
4749
# -------------------------------
4850
# 🔹 LayerNorm
4951
# -------------------------------
@@ -267,15 +269,26 @@ def __init__(
267269
def forward(self, eeg_data: np.ndarray) -> np.ndarray:
    """Run the full EEG pipeline: Time2Vec encoding, transformer encoder,
    attention pooling, then a final linear projection.

    >>> model = EEGTransformer(4, 2, 8, 2, seed=0)
    >>> x = np.ones((1, 3, 1))
    >>> out = model.forward(x)
    >>> out.shape
    (1, 1)
    """
    # Keep only the first feature channel so the input is (batch, seq_len, 1).
    if eeg_data.shape[-1] != 1:
        eeg_data = eeg_data[..., :1]

    hidden = self.time2vec.forward(eeg_data)  # Time2Vec positional encoding
    hidden = self.encoder.forward(hidden)     # transformer encoder stack
    pooled = self.pooling.forward(hidden)     # attention pooling over time

    # Final linear head; result has shape (batch, output_dim).
    return np.dot(pooled, self.w_out) + self.b_out
280293

281294

0 commit comments

Comments
 (0)