Commit 92c985f
add training notebook
1 parent d67e34b

1 file changed: notebooks/train_a_model.ipynb
Lines changed: 271 additions & 0 deletions
{
"cells": [
{
"cell_type": "markdown",
"id": "e53e9fe7",
"metadata": {},
"source": [
"# Model Training Notebook\n",
"\n",
"This notebook provides a simple interface to train different models on the BBBC021 dataset.\n",
"\n",
"## Available Models:\n",
"1. **Vanilla SimCLR** - Standard contrastive learning with data augmentations (optionally use weak labels to keep a positive pair's compound out of the negative pairs)\n",
"2. **Weak Supervision SimCLR** - Uses compound labels to create positive pairs\n",
"3. **WS-DINO** - Teacher-student distillation approach\n",
"\n",
"## Quick Start:\n",
"1. Set your training parameters in the configuration section (check out our training module for a more detailed look at which parameters to set for each training approach)\n",
"2. Choose your model type\n",
"3. Run the training cell"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d98f856a",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import torch\n",
"import gc\n",
"from pathlib import Path\n",
"\n",
"# Add the parent directory to path so we can import our modules\n",
"sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(''))))\n",
"\n",
"# Import our training functions\n",
"from training.simclr_vanilla_train import train_simclr_vanilla\n",
"from training.simclr_ws_train import train_simclr\n",
"from training.wsdino_resnet_train import train_wsdino\n",
"\n",
"print(\"Available devices:\")\n",
"if torch.cuda.is_available():\n",
" print(f\"CUDA: {torch.cuda.get_device_name(0)}\")\n",
" print(f\"CUDA Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB\")\n",
"else:\n",
" print(\"CPU only\")\n",
"\n",
"print(f\"PyTorch version: {torch.__version__}\")\n",
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
"if torch.cuda.is_available():\n",
" print(f\"Number of GPUs: {torch.cuda.device_count()}\")\n",
"\n",
"# Clean up any existing GPU memory\n",
"if torch.cuda.is_available():\n",
" torch.cuda.empty_cache()\n",
" gc.collect()"
]
},
{
"cell_type": "markdown",
"id": "207ac59e",
"metadata": {},
"source": [
"## Configuration\n",
"\n",
"Set your training parameters here. You can modify these values based on your computational resources and requirements."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b16b515b",
"metadata": {},
"outputs": [],
"source": [
"# TRAINING CONFIGURATION\n",
"\n",
"# Data path - Update this to point to your BBBC021 dataset\n",
"DATA_ROOT = \"/scratch/cv-course2025/group8\"\n",
"\n",
"# Model selection - Choose one of: 'vanilla_simclr', 'ws_simclr', 'wsdino'\n",
"MODEL_TYPE = \"vanilla_simclr\"\n",
"\n",
"# Training parameters\n",
"EPOCHS = 50 # Number of training epochs (reduce for testing)\n",
"BATCH_SIZE = 128 # Batch size (reduce if you get out-of-memory errors)\n",
"LEARNING_RATE = 0.0003 # Learning rate\n",
"TEMPERATURE = 0.1 # Temperature for contrastive loss\n",
"PROJECTION_DIM = 128 # Projection head output dimension\n",
"\n",
"# Saving options\n",
"SAVE_EVERY = 10 # Save model every N epochs\n",
"SAVE_DIR = \"/scratch/cv-course2025/group8/model_weights\" # Directory to save models\n",
"\n",
"# Advanced options (usually don't need to change)\n",
"COMPOUND_AWARE = True # For vanilla SimCLR: use compound-aware loss\n",
"MOMENTUM = 0.996 # For WS-DINO: teacher momentum\n",
"\n",
"print(\"Training Configuration:\")\n",
"print(f\" Model Type: {MODEL_TYPE}\")\n",
"print(f\" Data Root: {DATA_ROOT}\")\n",
"print(f\" Epochs: {EPOCHS}\")\n",
"print(f\" Batch Size: {BATCH_SIZE}\")\n",
"print(f\" Learning Rate: {LEARNING_RATE}\")\n",
"print(f\" Save Directory: {SAVE_DIR}\")\n",
"\n",
"# Create save directory if it doesn't exist\n",
"os.makedirs(SAVE_DIR, exist_ok=True)\n",
"print(f\" Save directory ready: {os.path.exists(SAVE_DIR)}\")"
]
},
{
"cell_type": "markdown",
"id": "33238a66",
"metadata": {},
"source": [
"## Model Information\n",
"\n",
"Here's a brief overview of each model type:\n",
"\n",
"### 1. Vanilla SimCLR\n",
"- **Method**: Standard contrastive learning with data augmentations\n",
"- **Positive pairs**: Two augmented versions of the same image\n",
"- You can use weak labels here to prevent the same compound from appearing in negative pairs; just set `compound_aware=True`\n",
"\n",
"### 2. Weak Supervision SimCLR (WS-SimCLR)\n",
"- **Method**: Uses compound labels to create positive pairs\n",
"- **Positive pairs**: Two different images from the same compound\n",
"\n",
"### 3. WS-DINO\n",
"- **Method**: Teacher-student distillation with weak supervision\n",
"- **Positive pairs**: Uses compound labels for supervision"
]
},
{
"cell_type": "markdown",
"id": "dc51b7d4",
"metadata": {},
"source": [
"## Training\n",
"\n",
"Run the cell below to start training with your configured parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "70a6df99",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# TRAINING EXECUTION\n",
"# =============================================================================\n",
"\n",
"def train_model(model_type, **kwargs):\n",
" \"\"\"\n",
" Train a model based on the specified type and parameters.\n",
" \"\"\"\n",
" # Clear GPU memory before training\n",
" if torch.cuda.is_available():\n",
" torch.cuda.empty_cache()\n",
" gc.collect()\n",
"\n",
" print(f\"Starting training for {model_type}\")\n",
" print(\"=\" * 50)\n",
"\n",
" try:\n",
" if model_type == \"vanilla_simclr\":\n",
" print(\"Training Vanilla SimCLR...\")\n",
" model = train_simclr_vanilla(\n",
" root_path=kwargs['root_path'],\n",
" epochs=kwargs['epochs'],\n",
" batch_size=kwargs['batch_size'],\n",
" learning_rate=kwargs['learning_rate'],\n",
" temperature=kwargs['temperature'],\n",
" projection_dim=kwargs['projection_dim'],\n",
" save_every=kwargs['save_every'],\n",
" save_dir=kwargs['save_dir'],\n",
" compound_aware=kwargs.get('compound_aware', True)\n",
" )\n",
"\n",
" elif model_type == \"ws_simclr\":\n",
" print(\"Training Weak Supervision SimCLR...\")\n",
" model = train_simclr(\n",
" root_path=kwargs['root_path'],\n",
" epochs=kwargs['epochs'],\n",
" batch_size=kwargs['batch_size'],\n",
" learning_rate=kwargs['learning_rate'],\n",
" temperature=kwargs['temperature'],\n",
" projection_dim=kwargs['projection_dim'],\n",
" save_every=kwargs['save_every']\n",
" )\n",
"\n",
" elif model_type == \"wsdino\":\n",
" print(\"Training WS-DINO...\")\n",
" model = train_wsdino(\n",
" root_path=kwargs['root_path'],\n",
" epochs=kwargs['epochs'],\n",
" batch_size=kwargs['batch_size'],\n",
" lr=kwargs['learning_rate'],\n",
" momentum=kwargs.get('momentum', 0.996),\n",
" temperature=kwargs['temperature'],\n",
" save_every=kwargs['save_every']\n",
" )\n",
"\n",
" else:\n",
" raise ValueError(f\"Unknown model type: {model_type}\")\n",
"\n",
" print(\"=\" * 50)\n",
" print(\"Training completed successfully!\")\n",
" print(f\"Models saved in: {kwargs['save_dir']}\")\n",
"\n",
" return model\n",
"\n",
" except Exception as e:\n",
" print(f\"Training failed with error: {e}\")\n",
" print(\"Please check your configuration and try again.\")\n",
" raise\n",
"\n",
"# Prepare training parameters\n",
"training_params = {\n",
" 'root_path': DATA_ROOT,\n",
" 'epochs': EPOCHS,\n",
" 'batch_size': BATCH_SIZE,\n",
" 'learning_rate': LEARNING_RATE,\n",
" 'temperature': TEMPERATURE,\n",
" 'projection_dim': PROJECTION_DIM,\n",
" 'save_every': SAVE_EVERY,\n",
" 'save_dir': SAVE_DIR,\n",
" 'compound_aware': COMPOUND_AWARE,\n",
" 'momentum': MOMENTUM\n",
"}\n",
"\n",
"print(\"Training parameters:\")\n",
"for key, value in training_params.items():\n",
" print(f\" {key}: {value}\")\n",
"\n",
"# Start training\n",
"print(f\"\\nStarting training with model type: {MODEL_TYPE}\")\n",
"trained_model = train_model(MODEL_TYPE, **training_params)"
]
},
{
"cell_type": "markdown",
"id": "2867e6fa",
"metadata": {},
"source": [
"## Save your model\n",
"\n",
"Depending on your training approach, you will find your model under `/scratch/cv-course2025/group8/model_weights/<training_approach>`. You can then use the extractor and evaluator to see how your model performed. If you think you created a WORTHY model, we recommend giving it a unique, somewhat descriptive name and renaming the folders containing your model/features."
]
},
{
"cell_type": "markdown",
"id": "6bd4782f",
"metadata": {},
"source": []
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
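A note on the TEMPERATURE parameter: the contrastive loss itself lives in the training modules, but a pure-numpy sketch of NT-Xent shows how temperature scales the similarity logits before the softmax. This is an illustrative reimplementation under the usual SimCLR conventions (views of example k at rows 2k and 2k+1), not the project's actual loss code:

```python
import numpy as np

def nt_xent_pair_loss(z, temperature=0.1):
    """NT-Xent loss for a batch of embeddings where z[2k] and z[2k+1]
    are the two views of example k. Pure-numpy sketch for illustration."""
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # L2-normalize embeddings
    sim = z @ z.T / temperature                       # temperature-scaled cosine similarities
    np.fill_diagonal(sim, -np.inf)                    # exclude self-similarity from the softmax
    n = len(z)
    pos = np.arange(n) ^ 1                            # partner index: (0,1), (2,3), ...
    log_prob = sim[np.arange(n), pos] - np.log(np.exp(sim).sum(axis=1))
    return -log_prob.mean()
```

Lower temperature sharpens the softmax over similarities, so well-separated positives are rewarded more strongly, which is why `TEMPERATURE = 0.1` yields a much smaller loss than `1.0` on cleanly clustered embeddings.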

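As a follow-up to the save-model note in the notebook: a minimal sketch for locating the newest checkpoint under the save directory. The per-approach subfolder layout matches the path given above; the `.pt` extension and the file-naming scheme are assumptions about how the training scripts write weights:

```python
from pathlib import Path
from typing import Optional

SAVE_DIR = "/scratch/cv-course2025/group8/model_weights"  # matches the notebook config
MODEL_TYPE = "vanilla_simclr"                             # training approach / subfolder

def latest_checkpoint(save_dir: str, approach: str) -> Optional[Path]:
    """Return the most recently modified .pt checkpoint for a training approach, or None."""
    folder = Path(save_dir) / approach
    if not folder.is_dir():
        return None
    checkpoints = sorted(folder.glob("*.pt"), key=lambda p: p.stat().st_mtime)
    return checkpoints[-1] if checkpoints else None

ckpt = latest_checkpoint(SAVE_DIR, MODEL_TYPE)
print(f"Latest checkpoint: {ckpt}" if ckpt else "No checkpoint found yet")
```

The returned path can then be handed to `torch.load(ckpt, map_location="cpu")` before running the extractor and evaluator.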