-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_simclr.py
More file actions
213 lines (177 loc) · 7.67 KB
/
train_simclr.py
File metadata and controls
213 lines (177 loc) · 7.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import os
import torch
import argparse
import matplotlib.pyplot as plt
from datetime import datetime
import numpy as np
from sklearn.manifold import TSNE
import pandas as pd
import seaborn as sns
# Import our modules
from data_preprocessing import create_data_loaders, visualize_augmentations
from simclr_model import SimCLRModel
def parse_args(argv=None):
    """Parse command line arguments for SimCLR training.

    Args:
        argv: Optional list of argument strings to parse. Defaults to
            ``None``, in which case argparse falls back to ``sys.argv[1:]``
            (the original behavior). Passing an explicit list makes the
            function unit-testable and reusable from other scripts.

    Returns:
        argparse.Namespace with data, model, training, checkpoint and
        visualization options.
    """
    parser = argparse.ArgumentParser(description='Train SimCLR model for underwater acoustic spectrograms')
    # Data parameters
    parser.add_argument('--data_dir', type=str, default='/home/ubuntu/data',
                        help='Directory containing spectrogram images')
    parser.add_argument('--output_dir', type=str, default='/home/ubuntu/output',
                        help='Directory to save outputs')
    # Model parameters
    parser.add_argument('--base_model', type=str, default='resnet18',
                        choices=['resnet18', 'resnet34', 'resnet50'],
                        help='Base model architecture')
    parser.add_argument('--pretrained', action='store_true',
                        help='Use pretrained weights for the base model')
    parser.add_argument('--projection_dim', type=int, default=128,
                        help='Dimension of projection head output')
    # Training parameters
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of epochs to train for')
    parser.add_argument('--lr', type=float, default=0.0003,
                        help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='Weight decay')
    parser.add_argument('--temperature', type=float, default=0.5,
                        help='Temperature parameter for NT-Xent loss')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Number of workers for data loading')
    # Checkpoint parameters
    # NOTE(review): --checkpoint_interval is parsed but not referenced in
    # main() within this file — confirm SimCLRModel consumes it, or wire it up.
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')
    parser.add_argument('--checkpoint_interval', type=int, default=10,
                        help='Save checkpoint every N epochs')
    # Visualization parameters
    parser.add_argument('--visualize_augmentations', action='store_true',
                        help='Visualize augmentations before training')
    return parser.parse_args(argv)
def setup_output_dir(output_dir):
    """Create a timestamped run directory under *output_dir* and return its path.

    The directory is named ``simclr_<YYYYmmdd_HHMMSS>`` so that repeated
    runs never clobber each other's outputs.
    """
    run_name = 'simclr_' + datetime.now().strftime('%Y%m%d_%H%M%S')
    run_dir = os.path.join(output_dir, run_name)
    os.makedirs(run_dir, exist_ok=True)
    return run_dir
# Filename substrings checked in order; first match wins. Order is preserved
# from the original elif chain so ambiguous names categorize identically.
_CATEGORY_PATTERNS = [
    ('bio_signal_whale', 'bio_whale'),
    ('bio_signal_fish', 'bio_fish'),
    ('bio_signal_coral', 'bio_coral'),
    ('manmade_boat', 'manmade_boat'),
    ('manmade_ship', 'manmade_ship'),
    ('manmade_submarine', 'manmade_submarine'),
    ('manmade_speedboat', 'manmade_speedboat'),
    ('transient', 'transient'),
]


def _category_from_filename(file_name):
    """Map a spectrogram file name to its acoustic source category string."""
    if file_name.startswith('env_noise'):
        return 'env_noise'
    for substring, category in _CATEGORY_PATTERNS:
        if substring in file_name:
            return category
    return 'unknown'


def visualize_features(model, data_loader, output_dir):
    """Extract features with *model* and save a t-SNE scatter plot.

    Args:
        model: Trained SimCLRModel exposing ``extract_features(data_loader)``.
        data_loader: Loader whose ``dataset.dataset.data_dir`` holds the
            source .png spectrograms (assumes a Subset-wrapped dataset —
            TODO confirm against create_data_loaders).
        output_dir: Directory where ``tsne_features.png`` is written.

    Returns:
        pandas.DataFrame with t-SNE coordinates, category and file name.

    NOTE(review): file names come from os.listdir, whose order is arbitrary
    and is NOT guaranteed to match the order features were extracted in —
    the original code acknowledged this simplification. Category labels in
    the plot may therefore be misaligned with their points; fix by tracking
    file names through the data loader.
    """
    # Extract feature matrix (rows assumed to follow data_loader order).
    features = model.extract_features(data_loader)

    file_names = []
    categories = []
    for file_name in os.listdir(data_loader.dataset.dataset.data_dir):
        if not file_name.endswith('.png'):
            continue
        file_names.append(file_name)
        categories.append(_category_from_filename(file_name))

    # Truncate labels to however many feature rows were actually produced.
    file_names = file_names[:features.shape[0]]
    categories = categories[:features.shape[0]]

    # t-SNE to 2-D; perplexity must be < n_samples, hence the min().
    tsne = TSNE(n_components=2, perplexity=min(30, len(features) - 1), random_state=42)
    tsne_result = tsne.fit_transform(features)

    df = pd.DataFrame({
        'x': tsne_result[:, 0],
        'y': tsne_result[:, 1],
        'category': categories,
        'file': file_names
    })

    # Scatter plot colored by category, saved (not shown) for headless runs.
    plt.figure(figsize=(12, 10))
    sns.scatterplot(data=df, x='x', y='y', hue='category', palette='viridis', alpha=0.8)
    plt.title('t-SNE Visualization of Extracted Features')
    plt.savefig(os.path.join(output_dir, 'tsne_features.png'))
    plt.close()
    return df
def main():
    """Run the full SimCLR pipeline: setup, training, saving, visualization."""
    args = parse_args()

    # Timestamped run directory for all artifacts of this run.
    output_dir = setup_output_dir(args.output_dir)
    print(f"Outputs will be saved to {output_dir}")

    # Persist the exact configuration alongside the results.
    with open(os.path.join(output_dir, 'args.txt'), 'w') as f:
        f.writelines(f"{arg}: {value}\n" for arg, value in vars(args).items())

    # Optional sanity check of the augmentation pipeline before training.
    if args.visualize_augmentations:
        print("Visualizing augmentations...")
        visualize_augmentations(args.data_dir, num_samples=3)

    # Paired-view loaders for contrastive training.
    print("Creating data loaders...")
    train_loader, val_loader = create_data_loaders(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        simclr_mode=True
    )
    print(f"Training set size: {len(train_loader.dataset)}")
    print(f"Validation set size: {len(val_loader.dataset)}")

    # Bundle hyperparameters for the model wrapper.
    config = {
        'base_model': args.base_model,
        'pretrained': args.pretrained,
        'projection_dim': args.projection_dim,
        'batch_size': args.batch_size,
        'temperature': args.temperature,
        'learning_rate': args.lr,
        'weight_decay': args.weight_decay,
        'epochs': args.epochs
    }

    print("Creating SimCLR model...")
    model = SimCLRModel(config)

    # Warm-start from a prior checkpoint when requested.
    if args.resume:
        print(f"Resuming from checkpoint {args.resume}...")
        model.load_checkpoint(args.resume)

    print("Starting training...")
    model.train(train_loader, val_loader, args.epochs)

    # Persist both the final weights and a resumable checkpoint.
    final_model_path = os.path.join(output_dir, 'final_model.pt')
    model.save_model(final_model_path)
    final_checkpoint_path = os.path.join(output_dir, 'final_checkpoint.pt')
    model.save_checkpoint(final_checkpoint_path)

    # Single-view loader (no paired augmentations) for feature extraction.
    eval_loader, _ = create_data_loaders(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        simclr_mode=False
    )

    print("Visualizing extracted features...")
    feature_df = visualize_features(model, eval_loader, output_dir)
    feature_df.to_csv(os.path.join(output_dir, 'feature_visualization.csv'), index=False)

    print(f"Training complete. Model saved to {final_model_path}")
    print(f"Results saved to {output_dir}")


if __name__ == "__main__":
    main()