<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Self-Supervised Learning for Time Series</title>
<style>
body {
font-family: sans-serif;
max-width: 900px;
margin: 0 auto;
padding: 2rem;
line-height: 1.6;
color: #333;
}
h1, h2, h3, h4 {
color: #2c3e50; /* Darker shade for better contrast */
margin-top: 1.5em;
margin-bottom: 0.5em;
}
h1 { font-size: 2.5em; border-bottom: 2px solid #3498db; padding-bottom: 0.3em;}
h2 { font-size: 2em; border-bottom: 1px solid #bdc3c7; padding-bottom: 0.2em;}
h3 { font-size: 1.5em; }
h4 { font-size: 1.2em; color: #555;}
nav { margin-bottom: 30px; padding: 10px; background: #ecf0f1; border: 1px solid #bdc3c7; border-radius: 4px;}
nav ul { list-style: none; padding: 0; }
nav li { display: inline-block; margin-right: 15px; }
nav a { text-decoration: none; color: #3498db; font-weight: bold;}
nav a:hover { text-decoration: underline; color: #2980b9;}
pre {
background: #f8f9f9; /* Lighter background for code blocks */
padding: 1rem;
overflow-x: auto;
border: 1px solid #e1e4e8; /* Softer border */
border-left: 4px solid #3498db; /* Accent border */
border-radius: 4px;
font-size: 0.9em;
}
code {
font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, Courier, monospace;
}
/* For inline code */
p > code, li > code, table td > code {
background: #e8eaed;
padding: 0.2em 0.4em;
border-radius: 3px;
font-size: 0.85em;
}
pre code { /* Reset for code inside pre, already handled by pre styling */
background: none;
padding: 0;
font-size: 1em; /* Ensure pre's font size is inherited */
}
ul {
padding-left: 20px;
}
li {
margin-bottom: 0.5em;
}
strong {
color: #2980b9;
}
</style>
<script type="text/javascript" async
src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-MML-AM_CHTML">
</script>
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [['$','$'], ['\\(','\\)']],
displayMath: [['$$','$$'], ['\\[','\\]']],
processEscapes: true
}
});
</script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/monokai.min.css">
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js"></script>
<script>hljs.highlightAll();</script>
</head>
<body>
<nav>
<ul>
<li><a href="index.html">Home</a></li>
</ul>
</nav>
<h1>Self-Supervised Learning (SSL) for Time Series Data</h1>
<p><b>Self-Supervised Learning (SSL)</b> is a type of representation learning in which models learn from unlabeled data by generating their own training signals (pseudo-labels) through <b>pretext tasks</b>. The learned features are then used for downstream tasks. This tutorial focuses on SSL applied to time series data.</p>
<p>SSL asks the model to predict one part of the data from another. This forces the model to learn meaningful internal structure of the data without requiring human labels. A typical SSL pipeline involves:</p>
<ol>
<li>Taking raw, unlabeled data.</li>
<li>Creating a pretext task the model must solve (e.g., mask prediction, contrastive matching).</li>
<li>Learning robust features by solving that task.</li>
<li>Applying the learned features to a downstream task, such as classification or detection.</li>
</ol>
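<p>The most common categories of pretext tasks are listed below. They overlap; for example, MAE is both a reconstruction method and a masked-prediction method.</p>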
<ul>
<li><b>Reconstruction</b>: Predicting missing parts of the input, like in Masked Autoencoders (MAE) and traditional Autoencoders.</li>
<li><b>Contrastive Learning</b>: Bringing similar "views" of data closer together in the latent space while pushing dissimilar views farther apart (e.g., SimCLR, MoCo, CPC).</li>
<li><b>Pretext Prediction</b>: Predicting transformations applied to the input, or predicting the next step in a sequence (e.g., temporal order prediction).</li>
<li><b>Masked Prediction</b>: Predicting randomly masked input tokens or patches (e.g., BERT, MAE).</li>
<li><b>Interpolation</b>: Predicting values between two signals (e.g., Mixup).</li>
</ul>
<p>While SSL is best known for pre-training Large Language Models (LLMs), it has recently found its way into sensor data applications (e.g., <a href="https://arxiv.org/abs/2410.13638" target="_blank">https://arxiv.org/abs/2410.13638</a>). Time series data such as EEG and ECoG recordings contain rich temporal and spatial patterns. SSL can learn generalized representations from such data that transfer well to new tasks and are more robust to missing or noisy data.</p>
<p>In a nutshell, SSL removes reliance on human labels by learning rich and transferable features through pre-training on large, uncurated datasets.</p>
<hr>
<h2>Reconstruction-Based SSL</h2>
<p>If a model can accurately reconstruct missing or masked input, it has likely learned meaningful patterns in the data. Reconstruction tasks push the model to learn temporal dependencies, local and global structure, and relationships between features.</p>
<h3>Autoencoders</h3>
<p><b>Autoencoders</b> are commonly used for reconstruction-based SSL. A simple autoencoder compresses the input into a lower-dimensional latent representation and then reconstructs the original input from this representation, typically using Mean Squared Error (MSE) loss.</p>
<p>$$ \text{x} \longrightarrow \text{Encoder} \longrightarrow \text{z} \longrightarrow \text{Decoder} \longrightarrow \hat{\text{x}} $$</p>
<p>Masked Autoencoders (MAE) extend this idea: the input is split into patches, a random subset of patches is masked, the encoder processes <em>only</em> the visible patches, and the decoder reconstructs <em>all</em> patches. The MSE loss, however, is computed only on the masked patches.</p>
<p>For time series data, the input is a segment with shape <code>[channels x time]</code>. The encoder compresses this into a latent representation, and the decoder reconstructs the original signal.</p>
<p>If the input <code>x</code> has shape <code>[batch_size, channels, time]</code>, <code>Conv1d</code> layers in the encoder can downsample the time axis and extract temporal features, producing a compact latent representation (e.g., 64 latent channels at 1/8 of the original length). <code>ConvTranspose1d</code> layers in the decoder then upsample back to the original input shape.</p>
<p>$$ \text{Input x} \longrightarrow \text{Encoder (Conv1D / Linear)} \longrightarrow \text{Latent z} \longrightarrow \text{Decoder} \longrightarrow \text{Reconstructed } \hat{\text{x}} $$</p>
<h4>What is ConvTranspose1d?</h4>
<p>Before diving back into the autoencoder, let's briefly look at <code>ConvTranspose1d</code>. It reverses the <em>shape</em> change of <code>Conv1d</code>: while a strided <code>Conv1d</code> typically downsamples the time dimension, <code>ConvTranspose1d</code> upsamples it. It is not a true inverse, since it does not recover the original values. It is also called transposed convolution, learned upsampling, or (loosely) deconvolution. For example, if you downsample with <code>stride = 2</code>, a <code>ConvTranspose1d</code> with <code>stride = 2</code> upsamples by a factor of 2.</p>
<p>Unlike fixed upsampling (e.g., nearest-neighbor or linear interpolation, which has no learnable parameters), <code>ConvTranspose1d</code> learns how to upsample. Conceptually, it inserts <code>stride - 1</code> zeros between consecutive input samples and then applies a learnable kernel that fills in the gaps. The scratch implementation follows exactly these steps: it builds the zero-inserted signal, then runs an ordinary convolution with padding <code>kernel_size - 1</code> and a kernel flipped in time, so the output grows to <code>(L_in - 1) * stride + kernel_size</code> samples. During training, the kernel weights learn to spread and blend each input value into a longer, meaningful signal.</p>
<pre><code class="language-python">
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvTranspose1dScratch(nn.Module):
    """Transposed 1D convolution built from zero-insertion plus a regular conv1d.

    Matches nn.ConvTranspose1d with padding=0 (no output_padding, dilation or groups).
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.stride = stride
        self.kernel_size = kernel_size
        # Same weight layout as nn.ConvTranspose1d: [in_channels, out_channels, kernel_size]
        self.weight = nn.Parameter(torch.randn(in_channels, out_channels, kernel_size))
        self.bias = nn.Parameter(torch.zeros(out_channels))

    def forward(self, x):
        B, C_in, L_in = x.shape
        # Step 1: insert (stride - 1) zeros between consecutive time steps
        L_upsampled = (L_in - 1) * self.stride + 1
        x_upsampled = x.new_zeros(B, C_in, L_upsampled)
        x_upsampled[:, :, ::self.stride] = x
        # Step 2: regular conv1d over the upsampled signal, with the kernel flipped in time,
        # in/out channels swapped, and (kernel_size - 1) padding on both sides
        weight = self.weight.flip(-1).transpose(0, 1)  # [out_channels, in_channels, kernel_size]
        out = F.conv1d(x_upsampled, weight, self.bias, stride=1, padding=self.kernel_size - 1)
        return out  # [B, out_channels, (L_in - 1) * stride + kernel_size]
</code></pre>
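<p>To check numerically that zero-insertion followed by an ordinary convolution really is a transposed convolution, the sketch below writes the same idea as a plain function and compares it with PyTorch's <code>F.conv_transpose1d</code>. The helper name <code>conv_transpose1d_by_zero_insertion</code> is hypothetical, and the sketch assumes <code>padding=0</code>, no <code>output_padding</code>, <code>dilation=1</code>, and <code>groups=1</code>.</p>
<pre><code class="language-python">
import torch
import torch.nn.functional as F

def conv_transpose1d_by_zero_insertion(x, weight, bias, stride):
    # Hypothetical helper. weight uses the ConvTranspose1d layout [C_in, C_out, K]
    B, C_in, L_in = x.shape
    K = weight.shape[-1]
    # Insert (stride - 1) zeros between consecutive time steps
    x_up = x.new_zeros(B, C_in, (L_in - 1) * stride + 1)
    x_up[:, :, ::stride] = x
    # Flip the kernel in time and swap in/out channels: [C_out, C_in, K]
    w = weight.flip(-1).transpose(0, 1)
    # Full padding (K - 1) gives output length (L_in - 1) * stride + K
    return F.conv1d(x_up, w, bias, stride=1, padding=K - 1)

torch.manual_seed(0)
x = torch.randn(2, 3, 10)          # [B, C_in, L_in]
weight = torch.randn(3, 5, 4)      # [C_in, C_out, K]
bias = torch.randn(5)

out_scratch = conv_transpose1d_by_zero_insertion(x, weight, bias, stride=2)
out_ref = F.conv_transpose1d(x, weight, bias, stride=2)
print(out_scratch.shape)                                  # torch.Size([2, 5, 22])
print(torch.allclose(out_scratch, out_ref, atol=1e-5))   # True
</code></pre>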
<p>Now, back to our autoencoder. The encoder compresses the signal into a shorter but deeper (more channels) representation, extracting low- and high-level features; this latent space captures the signal's essential structure. Each decoder layer doubles the number of time steps while refining the features, learning how to fill in the signal and recover the original time series from the latent representation. Training with MSE loss teaches the network to preserve the shape and structure of the original signal.</p>
<pre><code class="language-python">
import torch
import torch.nn as nn

class TimeSeriesAutoencoder(nn.Module):
    def __init__(self, input_channels=1, latent_dim=64):
        super().__init__()
        # Encoder: each Conv1d with stride 2 halves the time dimension
        self.encoder = nn.Sequential(
            nn.Conv1d(input_channels, 16, kernel_size=5, stride=2, padding=2),  # [B, 16, T/2]
            nn.ReLU(),
            nn.Conv1d(16, 32, kernel_size=5, stride=2, padding=2),  # [B, 32, T/4]
            nn.ReLU(),
            nn.Conv1d(32, latent_dim, kernel_size=5, stride=2, padding=2),  # [B, latent_dim, T/8]
            nn.ReLU(),
        )
        # Decoder: each ConvTranspose1d (kernel 4, stride 2, padding 1) doubles the time dimension
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(latent_dim, 32, kernel_size=4, stride=2, padding=1),  # [B, 32, T/4]
            nn.ReLU(),
            nn.ConvTranspose1d(32, 16, kernel_size=4, stride=2, padding=1),  # [B, 16, T/2]
            nn.ReLU(),
            nn.ConvTranspose1d(16, input_channels, kernel_size=4, stride=2, padding=1),  # [B, input_channels, T]
        )

    def forward(self, x):
        # x: [B, input_channels, T]; T should be divisible by 8 so the output length matches
        z = self.encoder(x)      # latent representation
        x_hat = self.decoder(z)  # reconstruction
        return x_hat
</code></pre>
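<p>A minimal training sketch for this autoencoder, assuming toy data: 64 noisy sine waves of 256 steps (a length divisible by 8, so the decoder output matches the input). To stay self-contained, it rebuilds the same layer stack as one <code>nn.Sequential</code>; the data and hyperparameters (<code>lr=1e-3</code>, 200 full-batch steps) are illustrative assumptions, not tuned values.</p>
<pre><code class="language-python">
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same layer stack as TimeSeriesAutoencoder(input_channels=1, latent_dim=64)
model = nn.Sequential(
    nn.Conv1d(1, 16, kernel_size=5, stride=2, padding=2), nn.ReLU(),
    nn.Conv1d(16, 32, kernel_size=5, stride=2, padding=2), nn.ReLU(),
    nn.Conv1d(32, 64, kernel_size=5, stride=2, padding=2), nn.ReLU(),
    nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),
    nn.ConvTranspose1d(32, 16, kernel_size=4, stride=2, padding=1), nn.ReLU(),
    nn.ConvTranspose1d(16, 1, kernel_size=4, stride=2, padding=1),
)

# Toy data (assumption): 64 noisy sine waves, 1 channel, 256 time steps
t = torch.linspace(0, 8 * math.pi, 256)
phase = 2 * math.pi * torch.rand(64, 1, 1)
x = torch.sin(t + phase) + 0.05 * torch.randn(64, 1, 256)   # [64, 1, 256]

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
losses = []
for step in range(200):
    x_hat = model(x)              # [64, 1, 256]
    loss = loss_fn(x_hat, x)      # reconstruction error
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(x_hat.shape)  # torch.Size([64, 1, 256])
print(f"MSE at step 1: {losses[0]:.4f}, at step 200: {losses[-1]:.4f}")
</code></pre>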
<h3>Masked Autoencoders (MAE)</h3>
<p>An <b>MAE (Masked Autoencoder)</b> model sees only a portion of the input signal (e.g., 25%) and learns to reconstruct the entire signal, with the emphasis on the masked parts. The encoder encodes the partial input, the decoder reconstructs the whole signal, and the MSE loss is applied only to the masked values.</p>
<p>The first step in MAE is to split the time signal into non-overlapping patches and map each patch to a vector, its <b>patch embedding</b>. This is typically done with a <code>Conv1d</code> layer whose stride equals its kernel size.</p>
<pre><code class="language-python">
# ----------------------
# Patch Embedding Module
# ----------------------
class PatchEmbed1D(nn.Module):
    def __init__(self, in_channels, embed_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        # Conv1d with kernel and stride equal to patch_size extracts non-overlapping patches
        self.proj = nn.Conv1d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: [B, C, T]
        x = self.proj(x)       # [B, embed_dim, T//patch_size] - one embedding per patch
        x = x.transpose(1, 2)  # [B, N_patches, embed_dim] - N_patches is now the sequence length
        return x
</code></pre>
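<p>A <code>Conv1d</code> whose kernel size and stride both equal the patch size is the same as cutting the signal into non-overlapping patches, flattening each one, and applying a shared <code>nn.Linear</code>. The sketch below checks this equivalence by copying the convolution weights into a linear layer; the sizes (3 channels, 400 steps, patch size 16, 32-dimensional embeddings) are arbitrary illustrative choices. The same <code>unfold</code>-and-reshape step is also a convenient way to obtain raw patch values as reconstruction targets.</p>
<pre><code class="language-python">
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, T, P, D = 2, 3, 400, 16, 32            # batch, channels, time steps, patch size, embed dim
N = T // P                                   # 25 patches
x = torch.randn(B, C, T)

# Patch embedding as in PatchEmbed1D
conv = nn.Conv1d(C, D, kernel_size=P, stride=P)
emb_conv = conv(x).transpose(1, 2)           # [B, N, D]

# Equivalent: cut into patches, flatten each (channel-major), apply a shared Linear layer
patches = x.unfold(-1, P, P)                 # [B, C, N, P]
patches = patches.permute(0, 2, 1, 3).reshape(B, N, C * P)   # [B, N, C*P]
linear = nn.Linear(C * P, D)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(D, C * P))   # conv weight is [D, C, P]
    linear.bias.copy_(conv.bias)
emb_linear = linear(patches)                 # [B, N, D]

print(emb_conv.shape)                                    # torch.Size([2, 25, 32])
print(torch.allclose(emb_conv, emb_linear, atol=1e-5))   # True
</code></pre>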
<p>Next, we randomly keep a subset of the patches (e.g., 25%) and record the original indices of both the kept and the masked patches, so that the original order can be restored later. The encoder transforms only the visible patches and outputs a tensor of shape <code>[B, N_visible, embed_dim]</code>. Each <code>nn.TransformerEncoderLayer</code> contains multi-head self-attention, a feed-forward network, two layer norms, and residual connections; the MAE encoder stacks several of these layers with <code>nn.TransformerEncoder</code>.</p>
<pre><code class="language-python">
# ----------------------
# MAE Encoder
# ----------------------
class MAEEncoder(nn.Module):
    def __init__(self, embed_dim, depth):
        super().__init__()
        # Use a standard TransformerEncoderLayer as the base layer
        layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4, dim_feedforward=embed_dim * 4)
        # Stack multiple layers using TransformerEncoder
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, x):
        # With the default batch_first=False, TransformerEncoder expects [SeqLen, Batch, EmbedDim]
        x = x.transpose(0, 1)     # [N_visible, B, embed_dim]
        x = self.encoder(x)
        return x.transpose(0, 1)  # [B, N_visible, embed_dim] - back to batch-first
</code></pre>
<p>The <code>random_masking</code> function selects a random subset of patches to keep. For each sample in the batch, it generates random noise, shuffles indices based on this noise, and selects the first <code>len_keep</code> indices as the visible patches. <code>torch.gather()</code> is used to select these patches from the input, resulting in <code>[B, N_keep, D]</code>.</p>
<pre><code class="language-python">
# ----------------------
# Random Masking
# ----------------------
def random_masking(x, mask_ratio):
    B, N, D = x.shape  # B: batch_size, N: num_patches, D: embed_dim
    len_keep = int(N * (1 - mask_ratio))
    noise = torch.rand(B, N, device=x.device)  # random noise for shuffling
    ids_shuffle = torch.argsort(noise, dim=1)  # shuffled indices
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # indices that restore the original order
    ids_keep = ids_shuffle[:, :len_keep]  # indices of the patches to keep
    # Use gather to select the visible patches
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D))
    return x_masked, ids_restore, ids_keep
</code></pre>
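<p>Why does a second <code>argsort</code> give <code>ids_restore</code>? Applying <code>argsort</code> to a permutation returns its inverse permutation, so gathering with <code>ids_restore</code> undoes the shuffle. A tiny sketch on six patch ids (sizes chosen only for illustration):</p>
<pre><code class="language-python">
import torch

torch.manual_seed(0)
x = torch.arange(6).view(1, 6)                    # patch ids in original order: [[0, 1, 2, 3, 4, 5]]
noise = torch.rand(1, 6)
ids_shuffle = torch.argsort(noise, dim=1)         # a random permutation
ids_restore = torch.argsort(ids_shuffle, dim=1)   # its inverse permutation

shuffled = torch.gather(x, 1, ids_shuffle)        # patches in shuffled order
restored = torch.gather(shuffled, 1, ids_restore) # back in original order
print(shuffled)
print(torch.equal(restored, x))                   # True

ids_keep = ids_shuffle[:, :2]                     # e.g., keep 2 of the 6 patches
print(ids_keep)                                   # original positions of the visible patches
</code></pre>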
<p>The decoder reconstructs the masked patches using a learnable mask token. It receives the encoded visible patches together with one copy of a shared, learnable mask token (shape <code>[1, 1, decoder_dim]</code>) for every masked patch. First, <code>self.proj</code> maps the encoder output to the decoder dimension. The mask token is then expanded to <code>[B, N_masked, decoder_dim]</code> and concatenated with the projected visible tokens, giving a sequence of shape <code>[B, N_visible + N_masked, decoder_dim]</code> that is still in shuffled order. Gathering with <code>ids_restore</code> puts every token back at its original time position, after which positional embeddings are added so that each mask token knows which position it must reconstruct. The full sequence then passes through a stack of Transformer (self-attention) layers, which lets the decoder learn global dependencies between all patches, masked and visible. Finally, a linear layer (<code>self.head</code>) maps each token back to the raw values of one patch (<code>[B, N_patches, patch_dim]</code>). The decoder predicts every patch, including the visible ones, but the MSE loss is computed <em>only</em> on the masked patches.</p>
<pre><code class="language-python">
# ----------------------
# MAE Decoder
# ----------------------
class MAEDecoder(nn.Module):
    def __init__(self, embed_dim, decoder_dim, patch_dim, depth, num_patches):
        super().__init__()
        self.mask_token = nn.Parameter(torch.randn(1, 1, decoder_dim))  # shared, learnable
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, decoder_dim))
        self.proj = nn.Linear(embed_dim, decoder_dim)  # project encoder output to decoder_dim
        # The decoder is also a stack of TransformerEncoder (self-attention) layers
        layer = nn.TransformerEncoderLayer(d_model=decoder_dim, nhead=4, dim_feedforward=decoder_dim * 4)
        self.decoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.head = nn.Linear(decoder_dim, patch_dim)  # maps each token to the raw values of one patch

    def forward(self, x_encoded, ids_restore):
        B, N_vis, _ = x_encoded.shape
        N_total = ids_restore.shape[1]
        N_mask = N_total - N_vis
        # Project encoded visible patches to the decoder dimension
        x_vis = self.proj(x_encoded)
        # Expand the mask token to [B, N_mask, decoder_dim]
        mask_tokens = self.mask_token.expand(B, N_mask, -1)
        # Concatenate visible and mask tokens (shuffled order), then restore the original order
        x_full = torch.cat([x_vis, mask_tokens], dim=1)
        x_full = torch.gather(x_full, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, x_full.size(-1)))
        # Add positional embeddings so each token knows its time position
        x_full = x_full + self.pos_embed[:, :N_total]
        # Pass through the decoder Transformer (expects [SeqLen, Batch, EmbedDim])
        x_full = self.decoder(x_full.transpose(0, 1)).transpose(0, 1)
        # Project to the raw values of each patch
        return self.head(x_full)  # [B, N_patches, patch_dim]
</code></pre>
<p>And here's the full MAE model:</p>
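<pre><code class="language-python">
# ----------------------
# Full MAE Model
# ----------------------
class MAEModel(nn.Module):
    def __init__(self, in_channels=1, patch_size=16, embed_dim=128, encoder_depth=4,
                 decoder_dim=64, decoder_depth=2, seq_len=400):
        super().__init__()
        self.patch_size = patch_size
        num_patches = seq_len // patch_size
        self.patch_embed = PatchEmbed1D(in_channels, embed_dim, patch_size)
        # Encoder-side positional embedding, added before masking so visible patches keep their position
        self.pos_embed = nn.Parameter(0.02 * torch.randn(1, num_patches, embed_dim))
        self.encoder = MAEEncoder(embed_dim, encoder_depth)
        # The decoder head outputs the flattened raw values of one patch (patch_size * in_channels)
        self.decoder = MAEDecoder(embed_dim, decoder_dim, patch_size * in_channels, decoder_depth, num_patches)

    def patchify(self, x):
        # [B, C, T] -> [B, N_patches, C * patch_size]: raw values used as reconstruction targets
        B, C, T = x.shape
        p = self.patch_size
        patches = x.unfold(-1, p, p)  # [B, C, N_patches, p]
        return patches.permute(0, 2, 1, 3).reshape(B, T // p, C * p)

    def forward(self, x, mask_ratio=0.75):
        # 1. Convert the input signal to patch embeddings and add positional information
        tokens = self.patch_embed(x) + self.pos_embed  # [B, N_patches, embed_dim]
        # 2. Randomly mask patches
        x_masked, ids_restore, ids_keep = random_masking(tokens, mask_ratio)
        # 3. Encode visible patches only
        encoded = self.encoder(x_masked)
        # 4. Decode to reconstruct all patches
        pred = self.decoder(encoded, ids_restore)  # [B, N_patches, patch_dim]
        # The loss compares raw patch values (not embeddings), so return them as targets
        target_patches = self.patchify(x)          # [B, N_patches, patch_dim]
        return pred, target_patches, ids_restore
</code></pre>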
<p>And here's the training loop. Let's assume we have a single-channel time series with 400 time steps, split into patches of 16 samples (25 patches). The decoder must rely on the context from the visible patches to infer the values of the masked ones. This encourages a global understanding of the sequence, similar to BERT's masked language modeling, where the model infers masked words from context. To compute the loss, we build a binary mask in shuffled order, with 0 for the first <code>len_keep</code> (visible) positions and 1 for the rest, and then use <code>ids_restore</code> to put it back into the original patch order. Multiplying the squared prediction error by this mask restricts the MSE to the masked patches <em>only</em>.</p>
<pre><code class="language-python">
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Assuming you have a dataset: each sample is [C, T]
class DummyDataset(Dataset):
    def __getitem__(self, idx):
        return torch.randn(1, 400)  # [C, T] - example: 1 channel, 400 time steps

    def __len__(self):
        return 1000

dataset = DummyDataset()
loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = MAEModel(in_channels=1, patch_size=16)  # 400 // 16 = 25 patches
optimizer = optim.Adam(model.parameters(), lr=1e-4)
mask_ratio = 0.75
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

for epoch in range(10):
    model.train()
    total_loss = 0.0
    for x in loader:
        x = x.to(device)  # [B, C, T]
        # Forward pass: pred and target_patches are [B, N_patches, patch_dim]
        pred, target_patches, ids_restore = model(x, mask_ratio=mask_ratio)
        # Build the binary mask (1 = masked, 0 = visible) in the original patch order
        B, N_patches, patch_dim = target_patches.shape
        len_keep = int(N_patches * (1 - mask_ratio))
        mask = torch.ones(B, N_patches, device=device)
        mask[:, :len_keep] = 0  # the first len_keep patches in shuffled order were visible
        mask = torch.gather(mask, 1, ids_restore)  # unshuffle to the original patch order
        # Squared error per value, averaged over the masked patches only
        loss_all_patches = (pred - target_patches) ** 2  # [B, N_patches, patch_dim]
        mask = mask.unsqueeze(-1).expand(-1, -1, patch_dim)
        loss = (loss_all_patches * mask).sum() / mask.sum()
        # Optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(loader):.4f}")
</code></pre>
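<p>The mask used for the MAE loss can be sanity-checked in isolation. The sketch below (a batch of 4, 25 patches, and mask ratio 0.75, chosen to match the example) builds the mask in shuffled order, unshuffles it with <code>ids_restore</code>, and confirms that the result equals scattering zeros directly at <code>ids_keep</code>, the indices of the visible patches. Scattering at <code>ids_restore[:, :len_keep]</code> instead would mark the wrong patches as visible, because <code>ids_restore</code> maps original positions to shuffled positions rather than listing the kept patches.</p>
<pre><code class="language-python">
import torch

torch.manual_seed(0)
B, N, mask_ratio = 4, 25, 0.75
len_keep = int(N * (1 - mask_ratio))             # 6 visible patches per sample

noise = torch.rand(B, N)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
ids_keep = ids_shuffle[:, :len_keep]

# Build the mask in shuffled order (first len_keep = visible), then unshuffle
mask = torch.ones(B, N)
mask[:, :len_keep] = 0
mask = torch.gather(mask, 1, ids_restore)        # 1 = masked, 0 = visible, original order

# Equivalent construction: scatter zeros directly at the kept indices
mask_scatter = torch.ones(B, N).scatter_(1, ids_keep, 0)
print(torch.equal(mask, mask_scatter))  # True
print(mask.sum(dim=1))                  # tensor([19., 19., 19., 19.])
</code></pre>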
<hr>
<h2>Contrastive Learning</h2>
<p>The core idea of <b>Contrastive Learning</b> is to pull the representations of similar samples together in the latent space and push those of dissimilar samples apart. There are several variations of contrastive learning.</p>
<h3>SimCLR (Simple Contrastive Learning of Representations)</h3>
<p><b>SimCLR</b> applies two different random augmentations to the same input signal. Both augmented "views" are passed through the same encoder, and the training goal is to maximize the similarity between their latent representations. This forces the encoder to learn features that are invariant to the augmentations.</p>
<p>First, we define the augmentations. These augmentations preserve the identity of the signal but make it appear different enough to challenge the model. For time series, examples include:</p>
<ul>
<li><code>time_crop</code>: Takes a random window of the signal, cropping it to a percentage (e.g., 80%) of the original length.</li>
<li><code>time_jitter</code>: Adds small Gaussian noise to the signal, teaching the model robustness to sensor noise.</li>
</ul>
<p>These augmentations create two different versions of the same signal, which serve as <b>positive pairs</b> for contrastive learning.</p>
<pre><code class="language-python">
import random  # for time_crop
import torch

# ------------------------
# Time-series augmentations
# ------------------------
def time_crop(x, crop_ratio=0.8):
    # Crop a random window of crop_ratio * T steps (one window shared by the whole batch)
    B, C, T = x.shape
    new_T = int(T * crop_ratio)
    if new_T >= T:  # nothing to crop
        return x
    start = random.randint(0, T - new_T)
    return x[:, :, start:start + new_T]

def time_jitter(x, sigma=0.01):
    # Add small Gaussian noise to every value
    return x + sigma * torch.randn_like(x)

def augment(x):
    # Crop, then jitter. The cropped view is shorter than the input ([B, C, int(0.8 * T)]).
    # That is fine for the SimCLR encoder, which ends with AdaptiveAvgPool1d and maps any
    # length to a fixed-size vector; models that need a fixed T would have to pad or
    # resample the crop back to the original length.
    cropped_x = time_crop(x)
    return time_jitter(cropped_x)
</code></pre>
<p>SimCLR processes two augmented views of each input. Both views pass through the same encoder and then a projection head. The <code>SimCLRModel</code> consists of a backbone encoder (<code>Conv1d</code> layers for feature extraction and downsampling, followed by <code>AdaptiveAvgPool1d</code> and <code>Flatten</code> to produce a fixed-length feature vector) and a <code>ProjectionHead</code>. The SimCLR paper showed that computing the contrastive loss on the output of a nonlinear projection head improves the learned representations; after pre-training, the projection head is typically discarded, and downstream tasks use the encoder output.</p>
<pre><code class="language-python">
import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------------------
# SimCLR Projection Head
# ------------------------
class ProjectionHead(nn.Module):
    def __init__(self, in_dim, hidden_dim=256, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        return self.net(x)

# ------------------------
# SimCLR Encoder + Projection
# ------------------------
class SimCLRModel(nn.Module):
    def __init__(self, in_channels=1, encoder_dim=256):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels, 64, kernel_size=5, stride=2, padding=2),  # [B, 64, T/2]
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2),  # [B, 128, T/4]
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),  # pool across time: [B, 128, 1]
            nn.Flatten(),  # [B, 128]
            nn.Linear(128, encoder_dim)  # [B, encoder_dim]
        )
        self.projector = ProjectionHead(encoder_dim)

    def forward(self, x):
        feat = self.encoder(x)  # [B, encoder_dim] - representation kept for downstream tasks
        proj = self.projector(feat)  # [B, out_dim]
        return F.normalize(proj, dim=1)  # L2-normalize for cosine similarity
</code></pre>
<p>Each projected representation is L2-normalized so that its norm is 1. This puts all vectors on the unit hypersphere, where cosine similarity reduces to a simple dot product, and it prevents the model from trivially changing vector magnitudes instead of learning meaningful directions.</p>
<p>The SimCLR loss function, <b>NT-Xent (Normalized Temperature-scaled Cross-Entropy Loss)</b>, is at the heart of SimCLR. The temperature is a critical hyperparameter: lower values make the softmax distribution sharper, which puts more weight on hard negatives.</p>
<p>Given two representations, $z_1$ and $z_2$, from the augmented views of the same batch (concatenated to shape <code>[2B, D]</code>), we compute all pairwise similarities. Entry $(i, j)$ of the <code>sim</code> matrix (shape <code>[2B, 2B]</code>) is the similarity between samples $i$ and $j$, so the diagonal ($i = j$) holds self-similarities. Each sample should be compared with its positive pair and with all other samples as negatives, but <em>not</em> with itself, which would be trivial. An identity-matrix <code>mask</code> therefore sets the self-similarity scores to a very large negative number (e.g., <code>-9e15</code>) so that they receive effectively zero probability after the softmax. <code>sim_targets</code> gives, for each row, the column index of its positive pair, which serves as the class label for the cross-entropy loss. The NT-Xent loss thus pulls $z_1$ toward $z_2$ (its positive pair) and pushes it away from all other (negative) samples in the batch.</p>
<pre><code class="language-python">
import torch.nn.functional as F

# ------------------------
# SimCLR Loss (NT-Xent)
# ------------------------
def nt_xent_loss(z1, z2, temperature=0.5):
    B = z1.size(0)  # batch size
    z = torch.cat([z1, z2], dim=0)  # concatenate both views: [2B, D]
    # Compute pairwise cosine similarity: [2B, 2B]
    # z.unsqueeze(1) -> [2B, 1, D]
    # z.unsqueeze(0) -> [1, 2B, D]
    # Result of cosine_similarity -> [2B, 2B]
    sim = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2)
    sim /= temperature  # scale similarities by temperature
    # Create labels for positive pairs
    # For z1 samples (0 to B-1), the positive pair is at B to 2B-1
    # For z2 samples (B to 2B-1), the positive pair is at 0 to B-1
    # E.g., if B=2: labels = [2, 3, 0, 1]
    labels = torch.arange(B, device=z.device)  # [0, 1, ..., B-1]
    sim_targets = torch.cat([labels + B, labels], dim=0)  # [B, B+1, ..., 2B-1, 0, 1, ..., B-1]
    # Create a mask to remove self-similarity (diagonal elements)
    mask = torch.eye(2 * B, device=z.device).bool()
    sim.masked_fill_(mask, -9e15)  # set self-similarity to a very large negative number
    # Compute cross-entropy loss: `sim` holds the logits, `sim_targets` the true classes.
    # The loss for each row i tries to classify sim_targets[i] as the correct positive.
    loss = F.cross_entropy(sim, sim_targets)
    return loss
</code></pre>
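<p>Written out per sample, for a batch of $B$ pairs stacked into $2B$ unit vectors, the NT-Xent term for sample $i$ with positive partner $p(i)$ (index $i \pm B$) and temperature $\tau$ is shown below, and the loss is the mean of $\ell_i$ over all $2B$ samples:</p>
<p>$$ \ell_i = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_{p(i)})/\tau\right)}{\sum_{k \neq i} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)} $$</p>
<p>The sketch evaluates this formula term by term and checks it against the vectorized cross-entropy form used in <code>nt_xent_loss</code>, here computed with a matrix product (equal to cosine similarity for unit vectors) and <code>-inf</code> on the diagonal. The inputs are random unit vectors with arbitrary sizes ($B = 4$, $D = 8$, $\tau = 0.5$).</p>
<pre><code class="language-python">
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, D, tau = 4, 8, 0.5
z1 = F.normalize(torch.randn(B, D), dim=1)   # view 1, unit vectors
z2 = F.normalize(torch.randn(B, D), dim=1)   # view 2, unit vectors
z = torch.cat([z1, z2], dim=0)               # [2B, D]

# Vectorized NT-Xent: logits are similarities / tau, self-similarity excluded
sim = (z @ z.t()) / tau                      # dot product = cosine similarity for unit vectors
sim = sim.masked_fill(torch.eye(2 * B, dtype=torch.bool), float("-inf"))
targets = torch.cat([torch.arange(B, 2 * B), torch.arange(0, B)])
loss_vectorized = F.cross_entropy(sim, targets).item()

# The same loss written term by term from the formula
row_losses = []
for i in range(2 * B):
    pos = (i + B) % (2 * B)                  # index of the positive pair of sample i
    numerator = math.exp(torch.dot(z[i], z[pos]).item() / tau)
    denominator = sum(math.exp(torch.dot(z[i], z[k]).item() / tau)
                      for k in range(2 * B) if k != i)
    row_losses.append(-math.log(numerator / denominator))
loss_explicit = sum(row_losses) / len(row_losses)

print(f"vectorized: {loss_vectorized:.6f}, explicit: {loss_explicit:.6f}")
</code></pre>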
<p>Batch size is crucial for SimCLR: larger batches provide more negative examples, which generally improves performance, so SimCLR often requires large batch sizes (e.g., 512 or more). Alternatives such as memory banks, or the queue of past keys used in MoCo, supply many negatives without needing a huge batch.</p>
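<p>A minimal SimCLR training loop sketch. To stay self-contained, it re-declares compact versions of the augmentation, the NT-Xent loss (computed with a matrix product and <code>-inf</code> on the diagonal), and a small encoder plus projection head standing in for <code>SimCLRModel</code>. The random dataset, batch size of 64, and two epochs are placeholder assumptions; real SimCLR runs use far larger batches and much longer training.</p>
<pre><code class="language-python">
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
random.seed(0)

def augment(x, crop_ratio=0.8, sigma=0.01):
    # Random crop (one window shared by the batch) followed by Gaussian jitter
    T = x.shape[-1]
    new_T = int(T * crop_ratio)
    start = random.randint(0, T - new_T)
    x = x[:, :, start:start + new_T]
    return x + sigma * torch.randn_like(x)

def nt_xent_loss(z1, z2, temperature=0.5):
    B = z1.size(0)
    z = torch.cat([z1, z2], dim=0)                       # [2B, D], rows are unit vectors
    sim = (z @ z.t()) / temperature                      # cosine similarity / temperature
    self_mask = torch.eye(2 * B, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))      # drop self-similarity
    targets = torch.cat([torch.arange(B, 2 * B), torch.arange(0, B)]).to(z.device)
    return F.cross_entropy(sim, targets)

# Compact stand-in for SimCLRModel: Conv1d backbone + MLP projection head
encoder = nn.Sequential(
    nn.Conv1d(1, 32, kernel_size=5, stride=2, padding=2), nn.ReLU(),
    nn.Conv1d(32, 64, kernel_size=5, stride=2, padding=2), nn.ReLU(),
    nn.AdaptiveAvgPool1d(1), nn.Flatten(),               # [B, 64] for any crop length
)
projector = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 32))
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(projector.parameters()), lr=1e-3)

data = torch.randn(256, 1, 400)   # hypothetical unlabeled data: 256 single-channel signals
batch_size = 64
history = []
for epoch in range(2):
    perm = torch.randperm(data.shape[0])
    epoch_loss = 0.0
    for i in range(0, data.shape[0], batch_size):
        x = data[perm[i:i + batch_size]]
        # Two independently augmented views of the same batch
        z1 = F.normalize(projector(encoder(augment(x))), dim=1)
        z2 = F.normalize(projector(encoder(augment(x))), dim=1)
        loss = nt_xent_loss(z1, z2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    history.append(epoch_loss / (data.shape[0] // batch_size))
    print(f"Epoch {epoch + 1}, Loss: {history[-1]:.4f}")

# After pre-training, the projector is dropped and encoder(x) provides the features
features = encoder(data[:8]).detach()
print(features.shape)  # torch.Size([8, 64])
</code></pre>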
<h3>Momentum Contrast (MoCo)</h3>
<p><b>MoCo (Momentum Contrast)</b> tackles the large-batch requirement of end-to-end contrastive methods such as SimCLR. In the NT-Xent loss, each sample is contrasted with every other sample in the current batch. More negatives give a stronger learning signal, but the batch sizes this requires are impractical on a single GPU. MoCo decouples the batch size from the number of negatives by maintaining a queue of previously encoded samples, which provides a large, dynamic set of negative examples. To keep the embeddings in this queue consistent, MoCo uses a <b>momentum encoder</b> for the "key" side of the contrastive pair:</p>
<p>$$ \theta_k \leftarrow m \cdot \theta_k + (1 - m) \cdot \theta_q $$</p>
<p>There are two encoders: a <b>query encoder</b> ($\theta_q$) and a <b>key encoder</b> ($\theta_k$). The query encoder's parameters are updated normally via gradient descent, while the key encoder's parameters slowly track them through an exponential moving average (EMA) with momentum $m$ (e.g., 0.999). The queue holds key representations produced by this slowly changing key encoder, which stabilizes training. It acts as a FIFO (First-In, First-Out) buffer: new keys overwrite the oldest entries.</p>
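<p>A small sketch of both mechanisms on toy tensors: the EMA update applied to a hypothetical one-layer encoder after a pretend gradient step, and a circular FIFO queue written with modular indexing. The modular indexing is a compact alternative to the explicit wrap-around branch in the MoCo class, and it assumes each batch of keys is no larger than the queue. All sizes and the momentum value 0.99 are illustrative.</p>
<pre><code class="language-python">
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)

# --- Momentum (EMA) update of the key encoder ---
m = 0.99
query_encoder = nn.Linear(4, 2)              # hypothetical tiny encoder
key_encoder = copy.deepcopy(query_encoder)
for p in key_encoder.parameters():
    p.requires_grad = False

with torch.no_grad():
    for p in query_encoder.parameters():     # pretend an optimizer step moved every weight by +1
        p.add_(1.0)
    for p_q, p_k in zip(query_encoder.parameters(), key_encoder.parameters()):
        p_k.mul_(m).add_(p_q, alpha=1 - m)   # theta_k = m * theta_k + (1 - m) * theta_q

# The key encoder moved only 1% of the way toward the query encoder
print(query_encoder.weight - key_encoder.weight)  # every entry is approximately 0.99

# --- Circular FIFO queue of keys ---
def enqueue(queue, ptr, keys):
    # Write keys starting at ptr, wrapping around; overwrites the oldest entries
    idx = (ptr + torch.arange(keys.shape[0])) % queue.shape[0]
    queue[idx] = keys
    return (ptr + keys.shape[0]) % queue.shape[0]   # new pointer

queue = torch.zeros(5, 2)
ptr = enqueue(queue, 0, torch.full((3, 2), 1.0))    # rows 0-2 get 1s, ptr = 3
ptr = enqueue(queue, ptr, torch.full((3, 2), 2.0))  # rows 3, 4, 0 get 2s, ptr = 1
print(queue[:, 0].tolist(), ptr)  # [2.0, 1.0, 1.0, 2.0, 2.0] 1
</code></pre>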
<p>For a given input, an augmented view <code>x_q</code> is encoded by the query encoder, and another augmented view <code>x_k</code> is encoded by the momentum-updated key encoder (with gradients stopped for <code>x_k</code>'s path).</p>
<pre><code class="language-python">
import copy  # For deep-copying the encoder
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoCo(nn.Module):
    def __init__(self, encoder, feature_dim=128, queue_size=1024, momentum=0.999, temperature=0.07):
        super().__init__()
        self.query_encoder = encoder  # Updated by gradients; must output feature_dim features
        self.key_encoder = copy.deepcopy(encoder)  # A copy for the key encoder
        # Freeze key encoder parameters: they are updated only by momentum
        for param in self.key_encoder.parameters():
            param.requires_grad = False
        # Register a buffer for the queue of (normalized) negative keys
        self.register_buffer("queue", F.normalize(torch.randn(queue_size, feature_dim), dim=1))
        # Pointer to the oldest entry of the circular queue
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
        self.momentum = momentum
        self.temperature = temperature
        self.queue_size = queue_size
        self.feature_dim = feature_dim

    @torch.no_grad()  # This update happens without gradient tracking
    def _momentum_update_key_encoder(self):
        """Momentum update of the key encoder: theta_k = m * theta_k + (1 - m) * theta_q"""
        for param_q, param_k in zip(self.query_encoder.parameters(), self.key_encoder.parameters()):
            param_k.data = self.momentum * param_k.data + (1 - self.momentum) * param_q.data

    @torch.no_grad()  # Queue operations do not require gradients
    def _dequeue_and_enqueue(self, keys):
        keys = keys.detach()  # Detach keys from the computation graph
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr[0])
        # Replace the oldest entries with the current keys
        if ptr + batch_size > self.queue_size:
            # Handle wrap-around: fill to the end, then continue from the start
            overflow = (ptr + batch_size) - self.queue_size
            self.queue[ptr:] = keys[:self.queue_size - ptr]
            self.queue[:overflow] = keys[self.queue_size - ptr:]
        else:
            self.queue[ptr:ptr + batch_size] = keys
        ptr = (ptr + batch_size) % self.queue_size  # Advance pointer to the new oldest entry
        self.queue_ptr[0] = ptr

    def forward(self, x_q, x_k):
        # Step 1: Encode
        q = self.query_encoder(x_q)  # Query embedding: [B, D]
        q = F.normalize(q, dim=1)  # Normalize query
        with torch.no_grad():  # No gradient for the key encoder path
            self._momentum_update_key_encoder()  # Update key encoder
            k = self.key_encoder(x_k)  # Key embedding: [B, D]
            k = F.normalize(k, dim=1)  # Normalize key
        # Step 2: Compute logits
        # Positives: dot product of each query with its corresponding key
        pos = torch.sum(q * k, dim=1, keepdim=True)  # [B, 1]
        # Negatives: dot product of each query with every entry in the queue.
        # clone() keeps the in-place queue update below from invalidating autograd.
        neg = torch.matmul(q, self.queue.clone().detach().T)  # [B, K]
        # Concatenate positive and negative logits and scale by the temperature
        logits = torch.cat([pos, neg], dim=1) / self.temperature  # [B, 1 + K]
        # Labels for cross-entropy: the first column (positive) is the correct class
        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
        # Step 3: Update the queue with the current batch's keys
        self._dequeue_and_enqueue(k)
        return logits, labels
</code></pre>
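<p>The <code>_momentum_update_key_encoder</code> function smoothly moves the key encoder toward the query encoder's weights. This matters because the queue contains keys computed over many past iterations. If the key encoder were simply replaced by the rapidly changing query encoder at every step, keys from different iterations would come from very different encoders and would be inconsistent with one another. With a momentum close to 1, the key encoder still tracks the query encoder but changes only slightly per step, so the embeddings in the queue remain comparable.</p>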
<p>The core contrastive step involves taking the dot product of matching <code>q</code> and <code>k</code> (positive pairs) and the dot product of <code>q</code> with all entries in the <code>queue</code> (negatives). These are concatenated to form the <code>logits</code> for the contrastive loss. The queue is then updated by storing the current batch's <code>k</code> embeddings and removing the oldest entries, acting as a circular buffer (<code>_dequeue_and_enqueue</code>).</p>
<p>The buffers registered with <code>self.register_buffer("queue", ...)</code> and <code>self.register_buffer("queue_ptr", ...)</code> are what allow MoCo to scale contrastive learning without huge batch sizes. PyTorch tracks buffers as part of the model (they are saved in the <code>state_dict</code> and moved by <code>.to(device)</code>), but they are not parameters and receive no gradients. Through <code>_dequeue_and_enqueue</code>, the queue acts as a rolling memory of negatives whose key embeddings were produced by the momentum-updated encoder. The negatives therefore come not from the current batch but from past batches stored in this buffer.</p>
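<p>The difference between buffers and parameters can be checked on a toy module. <code>QueueHolder</code> and its sizes are hypothetical and only mimic how MoCo registers its queue:</p>
<pre><code class="language-python">
import torch
import torch.nn as nn
import torch.nn.functional as F

class QueueHolder(nn.Module):  # hypothetical minimal module
    def __init__(self, queue_size=8, dim=4):
        super().__init__()
        self.proj = nn.Linear(dim, dim)  # a trainable parameter, for contrast
        self.register_buffer("queue", F.normalize(torch.randn(queue_size, dim), dim=1))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

m = QueueHolder()
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]
print(param_names)   # ['proj.weight', 'proj.bias']
print(buffer_names)  # ['queue', 'queue_ptr']
print(sorted(m.state_dict().keys()))  # buffers are saved alongside parameters
</code></pre>
<p>Because the queue is not returned by <code>parameters()</code>, an optimizer built from <code>model.parameters()</code> never touches it; only <code>_dequeue_and_enqueue</code> does.</p>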
<p>In essence:</p>
<ul>
<li>Positive pairs are <code>q</code> vs. <code>k</code> (current batch).</li>
<li>Negative pairs are <code>q</code> vs. <code>queue</code> (past batches).</li>
</ul>
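<p>In the training loop, each batch is augmented twice to obtain <code>x_q</code> and <code>x_k</code>, the model returns <code>logits</code> and <code>labels</code>, and the loss is the cross-entropy between them.</p>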
<p>Here, <code>q = query_encoder(x_q)</code> is used for gradient-based learning, while <code>k = key_encoder(x_k)</code> is generated under <code>torch.no_grad()</code> using momentum-updated weights. The <code>logits</code> are formed from <code>[q·k (positives), q·queue (negatives)]</code>. The <code>_momentum_update_key_encoder</code> function ensures the key encoder weights $\theta_k$ smoothly track $\theta_q$, keeping the queue embeddings stable across training steps. This consistency is vital for the negative queue to work effectively with small batches and delayed negatives.</p>
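<p>A minimal, self-contained sketch of MoCo training steps follows. It mirrors the model above but inlines the queue and the momentum update so it runs on its own; the toy MLP encoder, noise augmentations, queue size, and SGD settings are illustrative assumptions. The logits are divided by a temperature $\tau$ before the cross-entropy, as in MoCo (the original paper uses $\tau = 0.07$).</p>
<pre><code class="language-python">
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Toy stand-ins (assumptions): an MLP encoder and Gaussian-noise "augmentations"
query_encoder = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 16))
key_encoder = copy.deepcopy(query_encoder)
for p in key_encoder.parameters():
    p.requires_grad = False  # the key encoder is never trained by gradients

K, m, tau = 256, 0.999, 0.07  # queue size, momentum, temperature
queue = F.normalize(torch.randn(K, 16), dim=1)  # [K, D] negative keys
ptr = 0
optimizer = torch.optim.SGD(query_encoder.parameters(), lr=0.03, momentum=0.9)
data = torch.randn(512, 32)
losses = []

for step in range(10):
    x = data[torch.randint(0, data.shape[0], (32,))]
    x_q = x + 0.1 * torch.randn_like(x)  # view for the query encoder
    x_k = x + 0.1 * torch.randn_like(x)  # view for the key encoder

    q = F.normalize(query_encoder(x_q), dim=1)  # [B, D], receives gradients
    with torch.no_grad():
        for p_q, p_k in zip(query_encoder.parameters(), key_encoder.parameters()):
            p_k.mul_(m).add_(p_q, alpha=1 - m)  # theta_k = m*theta_k + (1-m)*theta_q
        k = F.normalize(key_encoder(x_k), dim=1)  # [B, D], no gradients

    l_pos = (q * k).sum(dim=1, keepdim=True)  # [B, 1] positive logits
    l_neg = q @ queue.T  # [B, K] negatives from past batches
    logits = torch.cat([l_pos, l_neg], dim=1) / tau
    labels = torch.zeros(x.shape[0], dtype=torch.long)  # positive is column 0
    loss = F.cross_entropy(logits, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    # FIFO update: overwrite the oldest keys (K is a multiple of B here, so no split is needed)
    B = k.shape[0]
    queue[ptr:ptr + B] = k
    ptr = (ptr + B) % K
</code></pre>
<p>Note that each step sees 256 negatives while the batch holds only 32 samples; growing <code>K</code> adds negatives without increasing memory for activations.</p>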
<p>To summarize, MoCo enables contrastive learning without relying on large batch sizes by maintaining a large queue of past key embeddings. These keys are encoded using a momentum-updated encoder, which ensures their consistency over time.</p>
<h3>CPC (Contrastive Predictive Coding)</h3>
<p><b>CPC (Contrastive Predictive Coding)</b> learns representations by dividing signals into segments and using early segments to predict future ones in a contrastive manner. Given a sequence, it first encodes it into latent vectors. Then, a context encoder (often an autoregressive model like a GRU) summarizes the past. Instead of directly reconstructing or regressing to future values, CPC uses a contrastive loss to make the predicted future representation similar to the true future representation while making it dissimilar to negative samples (other possible future representations from the batch). CPC tries to predict the future in a latent space; it doesn't reconstruct the input. The core idea: "Let me summarize the past, then guess what the future will look like in latent space. I'm correct if my prediction is more similar to the true future than to all other possible futures."</p>
<p>Here is CPC step by step:</p>
<ol>
<li>Raw input is a time series, $x = [x_1, x_2, \dots, x_T]$ with shape <code>[B, C, T]</code>.</li>
<li>Pass it through an encoder (often a stack of <code>Conv1d</code> layers) to get latent embeddings $z = \text{encoder}(x)$, with shape <code>[B, T', D]</code>.</li>
<li>Pass latent embeddings through a context encoder (e.g., GRU or Transformer) to get $c_t$, a summary of the past up to time $t$, also with shape <code>[B, T', D]</code>. $c_t$ summarizes $z_{1:t}$.</li>
<li>At each timestep $t$, predict the embeddings of the next $k$ steps from $c_t$. Each future offset $k$ has its own learnable linear layer $W_k$, which outputs a predicted latent $\hat{z}_{t+k} = W_k c_t$ for the true latent $z_{t+k}$.</li>
<li>Given context $c_t$, the goal is to maximize the similarity between the prediction $\hat{z}_{t+k}$ and the true future latent $z_{t+k}$, while minimizing its similarity with negative samples (all other latents $z_j$ in the batch that are not $z_{t+k}$). This is done using the <b>InfoNCE loss</b>:</li>
</ol>
<p>$$ \text{loss} = -\log\left(\frac{\exp(\text{sim}(c_t, z_{t+k}))}{\sum_j \exp(\text{sim}(c_t, z_j))}\right) $$</p>
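<p>The InfoNCE loss maximizes $\text{sim}(c_t, z_{t+k})$ and minimizes $\text{sim}(c_t, z_j)$ for the negative samples $z_j$; the sum in the denominator runs over the positive and all negatives. Here $\text{sim}(c_t, z)$ scores the prediction $W_k c_t$ against $z$ (the code below uses cosine similarity divided by a temperature). Negative samples are typically other latent vectors from the batch that do not correspond to the true future target. By learning to pick out the true future, CPC learns representations that capture temporal structure.</p>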
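<p>InfoNCE is exactly a cross-entropy over the candidates with the positive as the correct class. A small numeric check with made-up similarity scores (the values are arbitrary) shows the two forms agree:</p>
<pre><code class="language-python">
import math
import torch
import torch.nn.functional as F

# Made-up similarity scores for one context: index 0 is the true future, the rest are negatives
raw = [2.0, 0.5, -1.0, 0.0]
scores = torch.tensor([raw])

# InfoNCE written out directly: -log( exp(s_pos) / sum_j exp(s_j) )
manual = -math.log(math.exp(raw[0]) / sum(math.exp(s) for s in raw))

# The same quantity as a 4-way cross-entropy whose correct class is 0
ce = F.cross_entropy(scores, torch.tensor([0])).item()
print(manual, ce)  # the two values agree up to floating-point error
</code></pre>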
<p>Here’s an example of the CPC encoder:</p>
<pre><code class="language-python">
import torch.nn as nn

class CPCEncoder(nn.Module):
    def __init__(self, in_channels=1, latent_dim=128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels, 64, kernel_size=10, stride=5, padding=3),  # Downsample
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=8, stride=4, padding=2),  # Downsample more
            nn.ReLU(),
            nn.Conv1d(128, latent_dim, kernel_size=4, stride=2, padding=1),  # Final downsample to latent_dim
            nn.ReLU(),
        )

    def forward(self, x):
        """
        x: [B, C, T] - Input time series
        output: [B, T', D] - Latent embeddings, T' is reduced time length, D is latent_dim
        """
        z = self.encoder(x)  # [B, D_latent, T'] (Conv1d outputs [B, out_channels, out_length])
        z = z.permute(0, 2, 1)  # [B, T', D_latent] - Permute to (Batch, SequenceLength, FeatureDim) for GRU
        return z
</code></pre>
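<p>The reduced length $T'$ follows from the <code>Conv1d</code> output-length formula (with dilation 1). For a hypothetical input of $T = 16000$ samples (e.g., one second of 16 kHz audio, an assumption), the three layers above shrink the sequence by a total factor of 40:</p>
<pre><code class="language-python">
def conv1d_out_len(length, kernel_size, stride, padding):
    """Output length of nn.Conv1d with dilation=1."""
    return (length + 2 * padding - kernel_size) // stride + 1

# (kernel_size, stride, padding) for the three layers of CPCEncoder
layers = [(10, 5, 3), (8, 4, 2), (4, 2, 1)]

T = 16000  # assumed input length
lengths = [T]
for k, s, p in layers:
    lengths.append(conv1d_out_len(lengths[-1], k, s, p))
print(lengths)  # [16000, 3200, 800, 400]
</code></pre>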
<p>And here’s the CPC context encoder. The context vector $c_t$ (<code>c[:, t, :]</code>) summarizes everything up to and including time $t$.</p>
<pre><code class="language-python">
class CPCContext(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=128):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)  # GRU processes sequence

    def forward(self, z):
        """
        z: [B, T', D] - Latent embeddings from CPCEncoder
        c: [B, T', D] - Contextualized embeddings at each time step
        """
        c, _ = self.gru(z)  # c: context at each time step from GRU
        return c
</code></pre>
<p>We feed the context vector at time $t$ to the prediction head to predict the latent embeddings of future time steps: $\hat{z}_{t+k} = W_k c_t$, where $W_k$ is a learnable linear layer for future step $k$. Only the first $T'-k$ context vectors (where $T'$ is the length of the latent sequence) have a true latent $k$ steps ahead, so those are the ones fed into $W_k$. Each future offset has its own predictor: $W_1$ predicts $z_{t+1}$ from $c_t$, $W_2$ predicts $z_{t+2}$ from $c_t$, and so on. Separate layers let each offset specialize, because the relationship between $c_t$ and $z_{t+1}$ is not necessarily the same as between $c_t$ and $z_{t+5}$: $W_1$ might specialize in short-term predictions, while $W_5$ can learn coarser, long-range structure.</p>
<pre><code class="language-python">
class CPCPredictor(nn.Module):
    def __init__(self, latent_dim=128, k_steps=3):
        super().__init__()
        self.k_steps = k_steps
        self.predictors = nn.ModuleList([
            nn.Linear(latent_dim, latent_dim) for _ in range(k_steps)  # A linear layer W_k for each future step k
        ])

    def forward(self, context):
        """
        context: [B, T', D] (output of the GRU context encoder)
        returns: list of predictions, where
            preds[k-1]: [B, T' - k, D] predicts z_{t+k} from c_t for t = 0 .. T'-k-1
        """
        B, T_prime, D = context.shape
        preds = []
        for k, predictor in enumerate(self.predictors, start=1):
            # Only c_0 .. c_{T'-k-1} have a latent k steps ahead inside the sequence,
            # so their targets are z[:, k:] (z_k .. z_{T'-1}).
            # e.g., k=1: c_0 .. c_{T'-2} predict z_1 .. z_{T'-1}
            pred = predictor(context[:, :T_prime - k, :])
            preds.append(pred)
        return preds  # list of length k_steps; item k-1 has shape [B, T'-k, D]
</code></pre>
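<p>Because prediction number $k$ uses only the first $T'-k$ context vectors, its targets are the latents shifted by $k$, i.e., <code>z[:, k:]</code>. The pairing is easy to check in plain Python for a toy $T' = 6$; <code>cpc_index_pairs</code> is a hypothetical helper for illustration:</p>
<pre><code class="language-python">
def cpc_index_pairs(T_prime, k):
    """Pairs (t, t + k): context index t predicts latent index t + k."""
    ctx_idx = range(T_prime - k)  # rows of context[:, :T' - k]
    return [(t, t + k) for t in ctx_idx]  # targets are rows of z[:, k:]

for k in (1, 2, 3):
    print(k, cpc_index_pairs(6, k))
# 1 [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
# 2 [(0, 2), (1, 3), (2, 4), (3, 5)]
# 3 [(0, 3), (1, 4), (2, 5)]
</code></pre>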
<p>So far, we've encoded signals, used a context encoder to summarize the past, and used separate linear layers to predict the future. Next, we apply the InfoNCE loss between predicted and actual future embeddings. At each time step $t$ and for each offset $k$, we want the prediction $\hat{z}_{t+k}$ from context $c_t$ to score the true future latent $z_{t+k}$ as similar and all other latents as dissimilar. This maximizes similarity with the true future while minimizing similarity with the negatives.</p>
<p>The <code>targets</code> passed to InfoNCE are all positive latent vectors (the true future embeddings we want to predict). Negatives are implicitly included within the similarity matrix. We compute an $N \times N$ similarity matrix between every predicted embedding and every target embedding in the batch. The diagonal entries correspond to positive pairs, and off-diagonal entries correspond to negative pairs. The cross-entropy loss treats this as a classification task, where each prediction should select its positive target from all candidates. Labels simply tell the cross-entropy for each row which column is positive. This inherently encourages diagonal elements to be the highest in that row and off-diagonal elements to be pushed down. Instead of a binary classification, the loss is a multi-class classification problem.</p>
<pre><code class="language-python">
import torch
import torch.nn.functional as F

def info_nce_loss(preds, targets, temperature=0.07):
    """
    Computes InfoNCE loss.
    preds: [N, D] predicted latents (flattened batch and time for all predicted steps/batches)
    targets: [N, D] true latents (flattened batch and time; targets[i] is the positive for preds[i])
    """
    # Normalize embeddings to unit vectors for cosine similarity
    preds_norm = F.normalize(preds, dim=1)  # [N, D]
    targets_norm = F.normalize(targets, dim=1)  # [N, D]
    # Compute similarity matrix: (N_preds, N_targets)
    # The diagonal elements are positive pairs, off-diagonals are negatives
    similarity_matrix = torch.matmul(preds_norm, targets_norm.T) / temperature
    # Labels for cross-entropy are the indices of the positive targets
    # E.g., for preds[i], its positive target is targets[i]. So label is i.
    labels = torch.arange(preds.shape[0], device=preds.device)
    # Cross-entropy loss: maximizes the log-likelihood of correctly classifying
    # the true positive target among all other targets in the batch.
    loss = F.cross_entropy(similarity_matrix, labels)
    return loss
</code></pre>
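<p>Putting the pieces together, a training step encodes, contextualizes, predicts $k$ steps ahead, pairs each prediction with the latent $k$ steps later (<code>z[:, k:]</code>), and flattens batch and time before applying InfoNCE. To stay self-contained, the sketch below uses much smaller hypothetical stand-ins for the modules above and an inline copy of the InfoNCE loss; averaging the per-$k$ losses is one reasonable choice, not necessarily the original's.</p>
<pre><code class="language-python">
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C, T, D, K_STEPS = 4, 1, 800, 32, 3

# Hypothetical, much smaller stand-ins for CPCEncoder / CPCContext / CPCPredictor
encoder = nn.Sequential(nn.Conv1d(C, D, kernel_size=8, stride=4, padding=2), nn.ReLU())
gru = nn.GRU(D, D, batch_first=True)
predictors = nn.ModuleList([nn.Linear(D, D) for _ in range(K_STEPS)])

def info_nce(preds, targets, temperature=0.07):
    logits = F.normalize(preds, dim=1) @ F.normalize(targets, dim=1).T / temperature
    labels = torch.arange(preds.shape[0])  # row i's positive is column i
    return F.cross_entropy(logits, labels)

x = torch.randn(B, C, T)
z = encoder(x).permute(0, 2, 1)  # [B, T', D]
c, _ = gru(z)  # [B, T', D]
T_prime = z.shape[1]

losses = []
for k, W_k in enumerate(predictors, start=1):
    pred = W_k(c[:, :T_prime - k])  # predictions from c_0 .. c_{T'-k-1}
    target = z[:, k:]  # true latents z_k .. z_{T'-1}
    losses.append(info_nce(pred.reshape(-1, D), target.reshape(-1, D)))
loss = torch.stack(losses).mean()
loss.backward()
</code></pre>
<p>Flattening batch and time means the negatives for each prediction include latents from other time steps of the same sequence as well as from other sequences in the batch.</p>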
<p>Because CPC must predict latents several steps ahead, it is encouraged to encode information shared across time and to discard local, low-level noise, which tends to make its representations robust to irrelevant parts of the signal. After self-supervised pre-training, the learned representations can be applied to downstream tasks like classification, regression, and clustering. Typically, the encoder is frozen, and a new task-specific head is trained on top of it.</p>
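<p>A minimal linear-probe sketch: the pretrained encoder below is a stand-in, and mean-pooling over time, the number of classes, and the optimizer settings are illustrative assumptions.</p>
<pre><code class="language-python">
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
D, NUM_CLASSES = 32, 5
# Hypothetical stand-in for a pretrained encoder producing [B, D, T'] features
pretrained = nn.Sequential(nn.Conv1d(1, D, kernel_size=8, stride=4, padding=2), nn.ReLU())
for p in pretrained.parameters():
    p.requires_grad = False  # freeze: only the new head is trained
pretrained.eval()

head = nn.Linear(D, NUM_CLASSES)  # new task-specific classifier
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)

x = torch.randn(16, 1, 400)
y = torch.randint(0, NUM_CLASSES, (16,))
with torch.no_grad():
    feats = pretrained(x).mean(dim=2)  # average-pool over time: [B, D]
logits = head(feats)
loss = F.cross_entropy(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
</code></pre>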
<hr>
<h2>BYOL (Bootstrap Your Own Latent)</h2>
<p>Before discussing BYOL, let's clarify some terminology:</p>
<ul>
<li><b>Bootstrapping</b>: In statistics, this refers to resampling a dataset with replacement to create many new datasets of the same size. In bagging (bootstrap aggregating), a model is trained on each resample and their predictions are averaged or voted on, which mainly reduces variance and helps avoid overfitting (e.g., Random Forest bootstraps the data for each tree and aggregates the trees' predictions).</li>
<li><b>Boosting</b>: This involves training models sequentially, where each subsequent model focuses on correcting the mistakes made by previous models. The final prediction is a weighted sum of all models. This reduces bias as each model learns to refine the previous one's output (e.g., AdaBoost, Gradient Boosting, XGBoost). AdaBoost, for instance, re-weights samples that were misclassified, combining weak learners like decision stumps.</li>
</ul>
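<p>The statistical sense of bootstrapping is easy to see on a toy sample: each resample draws $n$ points with replacement, and the spread of the resampled means estimates the standard error of the mean. The data and the <code>bootstrap_means</code> helper are hypothetical:</p>
<pre><code class="language-python">
import random
import statistics

def bootstrap_means(data, n_resamples=1000, seed=0):
    """Resample `data` with replacement and return the mean of each resample."""
    rng = random.Random(seed)
    n = len(data)
    return [statistics.fmean(rng.choices(data, k=n)) for _ in range(n_resamples)]

data = [2.1, 3.4, 1.9, 5.0, 4.2, 3.3, 2.8, 3.9]
means = bootstrap_means(data)
# The standard deviation of the bootstrap means estimates the standard error of the mean
print(round(statistics.stdev(means), 3))
</code></pre>
<p>Bagging would go one step further and fit one model per resample instead of computing one statistic.</p>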
<p><b>BYOL (Bootstrap Your Own Latent)</b> is a self-supervised learning method that learns from positive pairs only, without requiring negative samples or explicit labels. The model effectively creates its own training signal. It uses two networks:</p>
<ol>
<li>An <b>online network</b>: This network is trained with standard gradient descent.</li>
<li>A <b>target network</b>: Its parameters are a slowly moving average (exponential moving average, EMA) of the online network's parameters, and it receives no gradient updates.</li>
</ol>
<p>The online network is trained to predict the output representation of the target network. In essence, the model improves itself by predicting its own slowly updated version, learning to map different augmented views of the same data to nearby points in the latent space without external supervision. This is the "bootstrapping" aspect: the model pulls itself up by using its own evolving outputs, produced by a stable, slowly updated target network, as prediction targets.</p>
<p>The task is to predict the target projection of one augmented view from the online network's output for another augmented view. Two augmented views of the same data ($x_1$ and $x_2$) are generated, and each passes through both networks. The online path consists of an encoder, a projector, and a predictor; the target path consists of an encoder and a projector (no predictor). The loss is based on the cosine similarity between the online prediction and the target projection. Gradients flow only through the online network, and the target network's weights are updated by EMA.</p>
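<p>For unit vectors, $\|p - z\|^2 = \|p\|^2 + \|z\|^2 - 2\, p \cdot z = 2 - 2\cos(p, z)$, so minimizing the squared distance between normalized outputs is the same as maximizing their cosine similarity. A quick numeric check with random stand-in vectors:</p>
<pre><code class="language-python">
import torch
import torch.nn.functional as F

torch.manual_seed(0)
p = F.normalize(torch.randn(8, 256), dim=-1)  # stand-in for online predictions
z = F.normalize(torch.randn(8, 256), dim=-1)  # stand-in for target projections

cos_form = 2 - 2 * (p * z).sum(dim=-1)  # 2 - 2 * cosine similarity
mse_form = (p - z).pow(2).sum(dim=-1)  # squared Euclidean distance
print(torch.allclose(cos_form, mse_form, atol=1e-5))  # True
</code></pre>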
<p>Let's assume we have a base encoder. We'll wrap it with projector and predictor MLPs. These MLPs output embeddings of <code>out_dim</code> size.</p>
<pre><code class="language-python">
import torch.nn as nn

class MLPHead(nn.Module):
    def __init__(self, in_dim, hidden_dim=4096, out_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),  # BatchNorm can help stabilize training
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        return self.net(x)
</code></pre>
<p>The online and target networks are initially identical. The <code>_update_target_network</code> method performs the EMA update, smoothly tracking the online network's parameters to keep the target network stable.</p>
<pre><code class="language-python">
import torch
import torch.nn as nn
import torch.nn.functional as F

class BYOL(nn.Module):
    def __init__(self, base_encoder_fn, in_dim, hidden_dim=4096, out_dim=256, ema_decay=0.99):
        super().__init__()
        # Online network: trainable via gradients
        self.online_encoder = base_encoder_fn()  # base_encoder_fn: a callable returning an encoder with in_dim outputs
        self.online_projector = MLPHead(in_dim, hidden_dim, out_dim)
        self.online_predictor = MLPHead(out_dim, hidden_dim, out_dim)  # Predictor adds asymmetry
        # Target network: parameters updated only via EMA, never by gradients
        self.target_encoder = base_encoder_fn()
        self.target_projector = MLPHead(in_dim, hidden_dim, out_dim)
        for param in list(self.target_encoder.parameters()) + list(self.target_projector.parameters()):
            param.requires_grad = False
        # Initialize target network parameters to match online network
        self._update_target_network(ema=0)  # ema=0 means direct copy
        self.ema_decay = ema_decay

    @torch.no_grad()  # Ensure no gradients are computed for target network update
    def _update_target_network(self, ema=None):
        """EMA update for target network parameters based on online network parameters."""
        for online_param, target_param in zip(
                self.online_encoder.parameters(), self.target_encoder.parameters()):
            target_param.data = (
                ema * target_param.data + (1 - ema) * online_param.data
                if ema is not None else online_param.data.clone()
            )
        for online_param, target_param in zip(
                self.online_projector.parameters(), self.target_projector.parameters()):
            target_param.data = (
                ema * target_param.data + (1 - ema) * online_param.data
                if ema is not None else online_param.data.clone()
            )

    # The forward pass of BYOL model
    def forward(self, x1, x2):
        # Online network forward pass
        # x1 -> encoder -> projector -> predictor -> o1
        o1 = self.online_predictor(self.online_projector(self.online_encoder(x1)))
        # x2 -> encoder -> projector -> predictor -> o2
        o2 = self.online_predictor(self.online_projector(self.online_encoder(x2)))
        # Target network forward pass (gradients are stopped)
        with torch.no_grad():
            # EMA update before computing this step's targets, so the target
            # network trails the online network as of the previous optimizer step
            self._update_target_network(ema=self.ema_decay)
            # x1 -> target_encoder -> target_projector -> t1
            t1 = self.target_projector(self.target_encoder(x1))
            # x2 -> target_encoder -> target_projector -> t2
            t2 = self.target_projector(self.target_encoder(x2))
        # Normalize outputs to unit vectors for cosine similarity loss
        o1 = F.normalize(o1, dim=-1)
        o2 = F.normalize(o2, dim=-1)
        t1 = F.normalize(t1, dim=-1)
        t2 = F.normalize(t2, dim=-1)
        # Symmetric loss: (1 - cos(o1, t2)) + (1 - cos(o2, t1)), i.e. the average of
        # the two terms 2 - 2*cos; t1 and t2 are detached (no gradients to the target)
        loss = 2 - 2 * (
            (o1 * t2.detach()).sum(dim=-1).mean()  # Dot product = cosine similarity for unit vectors
            + (o2 * t1.detach()).sum(dim=-1).mean()
        ) / 2
        return loss
</code></pre>
<p>Two augmented views of the same input are forwarded through both networks. The projector maps the encoder output into the latent space where the loss is computed, and the predictor, which exists only on the online side, maps the online projection to a prediction of the target projection. The predictor introduces asymmetry between the online and target networks. This asymmetry is crucial: without it, the model can collapse to a trivial constant output (representational collapse). The predictor gives the online network the flexibility to learn how to align its representations with the target network's stable representations. Specifically, $x_1$ and $x_2$ go through the online network (encoder $\rightarrow$ projector $\rightarrow$ predictor) to produce $o_1$ and $o_2$. Separately, $x_1$ and $x_2$ go through the target network (encoder $\rightarrow$ projector) to produce $t_1$ and $t_2$, with gradients stopped for this path. BYOL uses cosine similarity, so all outputs are normalized to unit vectors. The loss matches $o_1$ with $t_2$ and $o_2$ with $t_1$ (cross-view prediction), so the online network must predict one view's representation from the other, which pushes it toward view-invariant features rather than shortcuts.</p>
<p>The training loop would involve taking a batch $X$, augmenting it into $x_1$ and $x_2$, computing the BYOL loss between the online and target views, backpropagating through the online model, and then updating the target network's weights using EMA.</p>
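<p>The loop just described can be sketched as follows, assuming toy MLP modules, additive noise in place of real augmentations, and Adam. Unlike the <code>BYOL</code> class above, this sketch applies the EMA update after each optimizer step, which is the ordering in the BYOL paper's algorithm; the loss sums the two cross-view terms.</p>
<pre><code class="language-python">
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def mlp(i, h, o):  # small projector/predictor head (hypothetical sizes)
    return nn.Sequential(nn.Linear(i, h), nn.BatchNorm1d(h), nn.ReLU(), nn.Linear(h, o))

encoder = nn.Sequential(nn.Linear(32, 64), nn.ReLU())  # toy base encoder
online = nn.Sequential(encoder, mlp(64, 128, 32))  # encoder + projector
predictor = mlp(32, 128, 32)  # online-only predictor
target = copy.deepcopy(online)  # identical to the online network at the start
for p in target.parameters():
    p.requires_grad = False

optimizer = torch.optim.Adam(list(online.parameters()) + list(predictor.parameters()), lr=1e-3)
tau = 0.99  # EMA decay
data = torch.randn(256, 32)

def regression_loss(p, z):
    p, z = F.normalize(p, dim=-1), F.normalize(z, dim=-1)
    return (2 - 2 * (p * z).sum(dim=-1)).mean()

losses = []
for step in range(10):
    x = data[torch.randint(0, data.shape[0], (32,))]
    x1, x2 = x + 0.1 * torch.randn_like(x), x + 0.1 * torch.randn_like(x)
    p1, p2 = predictor(online(x1)), predictor(online(x2))
    with torch.no_grad():  # stop-gradient on the target path
        z1, z2 = target(x1), target(x2)
    loss = regression_loss(p1, z2) + regression_loss(p2, z1)  # symmetric, cross-view
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    with torch.no_grad():  # EMA update after the optimizer step
        for p_o, p_t in zip(online.parameters(), target.parameters()):
            p_t.mul_(tau).add_(p_o, alpha=1 - tau)
    losses.append(loss.item())
</code></pre>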
</body>
</html>