-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexperiments.py
More file actions
396 lines (316 loc) · 12.8 KB
/
experiments.py
File metadata and controls
396 lines (316 loc) · 12.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
"""
Experiments for the OctopusNet paper.

Each ``experiment_*`` function runs one ablation or comparison and returns its
results; ``run_experiment`` trains a single config over several seeds and
averages accuracy; ``run_all_experiments`` runs every experiment for each
requested dataset and writes the collected results to a timestamped JSON file.

Experiment IDs implemented in this module: A1, A2, A6-A12, A15.
"""
import torch
import json
from datetime import datetime
from tqdm import tqdm
from config import (
OctopusNetConfig,
get_baseline_config,
get_ablation_no_nerve_ring,
get_ablation_no_feedback,
get_gumbel_competition_config,
get_topk_competition_config,
get_multiscale_config,
get_adaptive_kernels_only_config,
get_channel_grouping_config,
)
from octopusnet import OctopusNet
from train import train, evaluate
from data import get_dataloaders
def run_experiment(name, config, num_runs=3):
    """
    Train ``config`` ``num_runs`` times with different seeds and aggregate accuracy.

    Run ``r`` (1-based) seeds torch with ``r * 42`` before calling ``train``,
    so repeated invocations are reproducible to the extent ``train`` is.

    Args:
        name: Label used in the console banners and stored under ``'name'``.
        config: Config object passed to ``train``; stored as ``str(config)``.
        num_runs: Number of independent training runs; must be >= 1.

    Returns:
        dict with keys ``name``, ``config``, ``runs`` (one dict per run holding
        ``best_acc``, ``final_acc`` and the raw ``history``), ``avg_best_acc``,
        ``avg_final_acc``, and the sample standard deviations across runs
        ``std_best_acc`` / ``std_final_acc`` (0.0 when ``num_runs == 1``).

    Raises:
        ValueError: if ``num_runs`` < 1 (the averages would divide by zero).
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs!r}")

    print(f"\n{'#'*60}")
    print(f"# Experiment: {name}")
    print(f"{'#'*60}")
    results = {
        'name': name,
        'config': str(config),
        'runs': []
    }
    for run in range(1, num_runs + 1):
        print(f"\n--- Run {run}/{num_runs} ---")
        # Distinct, reproducible seed per run (42, 84, 126, ...).
        torch.manual_seed(run * 42)
        _, history = train(config)
        best_acc = max(history['test_acc'])
        results['runs'].append({
            'best_acc': best_acc,
            'final_acc': history['test_acc'][-1],
            'history': history
        })

    best_accs = [r['best_acc'] for r in results['runs']]
    final_accs = [r['final_acc'] for r in results['runs']]
    avg_best = sum(best_accs) / num_runs
    avg_final = sum(final_accs) / num_runs
    # Spread across seeds is needed to judge whether differences between
    # experiment arms are meaningful.
    std_best = _sample_std(best_accs)
    std_final = _sample_std(final_accs)
    results['avg_best_acc'] = avg_best
    results['avg_final_acc'] = avg_final
    results['std_best_acc'] = std_best
    results['std_final_acc'] = std_final

    print(f"\n{'='*40}")
    print(f"Experiment: {name}")
    print(f"Average best accuracy: {avg_best:.4f}")
    print(f"Average final accuracy: {avg_final:.4f}")
    print(f"Std best accuracy: {std_best:.4f}")
    print(f"Std final accuracy: {std_final:.4f}")
    print(f"{'='*40}")
    return results


def _sample_std(values):
    """Return the sample (n-1) standard deviation of ``values``; 0.0 for fewer than 2 values.

    Uses plain arithmetic rather than ``statistics.stdev`` so it also works if
    the accuracies are 0-dim tensors instead of Python floats.
    """
    count = len(values)
    if count < 2:
        return 0.0
    mean = sum(values) / count
    return (sum((v - mean) ** 2 for v in values) / (count - 1)) ** 0.5
def experiment_A1_num_modules(dataset="cifar10"):
    """A1: Vary the number of modules (2, 4, 8, 16).

    Each module count is trained with ``epochs=50`` over 3 seeds via
    ``run_experiment``. Module ``i`` is assigned kernel size
    ``[3, 5, 7, 9, 11, 13, 15, 17][i % 8]``, so ``kernel_sizes`` always has
    exactly ``num_modules`` entries.

    Args:
        dataset: Dataset name forwarded to ``OctopusNetConfig``.

    Returns:
        list of ``run_experiment`` result dicts, one per module count.
    """
    base_kernel_sizes = [3, 5, 7, 9, 11, 13, 15, 17]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    results = []
    for n in [2, 4, 8, 16]:
        # A plain slice of the 8 base sizes would give n=16 only 8 kernel
        # sizes for 16 modules; cycling keeps one entry per module (and is
        # identical to the slice for n <= 8).
        kernel_sizes = [
            base_kernel_sizes[i % len(base_kernel_sizes)] for i in range(n)
        ]
        config = OctopusNetConfig(
            dataset=dataset,
            num_modules=n,
            kernel_sizes=kernel_sizes,
            epochs=50,
            device=device,
        )
        results.append(run_experiment(f"A1_modules_{n}", config, num_runs=3))
    return results
def experiment_A2_bottleneck_size(dataset="cifar10"):
    """A2: Sweep the bottleneck width over 8, 16, 32, 64 and 128.

    Every width is trained for 50 epochs over 3 seeds with ``run_experiment``.

    Args:
        dataset: Dataset name forwarded to ``OctopusNetConfig``.

    Returns:
        list of ``run_experiment`` result dicts, ordered by bottleneck width.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    sweep = []
    for width in (8, 16, 32, 64, 128):
        cfg = OctopusNetConfig(
            dataset=dataset,
            bottleneck_size=width,
            epochs=50,
            device=target_device,
        )
        sweep.append(run_experiment(f"A2_bottleneck_{width}", cfg, num_runs=3))
    return sweep
def experiment_A6_resilience(dataset="cifar10"):
    """A6: Measure how accuracy degrades when modules are disabled.

    Trains a single full model (one run, 50 epochs), then re-evaluates it on
    the test set with:
      * each module disabled on its own, and
      * modules ``0..k-1`` disabled together for k = 2 and 3 (skipped when
        k >= ``model.num_modules``).
    Disabled modules are restored with ``model.restore_module`` after each
    measurement, so every measurement starts from the trained weights.

    Returns:
        dict with ``name``, ``full_accuracy`` and a ``degradation`` list.
        Single-module entries use the key ``disabled_module`` (int);
        multi-module entries use ``disabled_modules`` (list of ints).
    """
    banner = "#" * 60
    print("\n" + banner)
    print("# Experiment A6: Resilience")
    print(banner)

    config = OctopusNetConfig(
        dataset=dataset,
        epochs=50,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    # Train the intact model once; every ablation below reuses these weights.
    model, _ = train(config)
    _, test_loader = get_dataloaders(config)

    baseline = evaluate(model, test_loader, config)
    results = {
        'name': 'A6_resilience',
        'full_accuracy': baseline,
        'degradation': [],
    }
    print(f"Full model accuracy: {baseline:.4f}")

    def accuracy_without(indices):
        """Disable ``indices`` in order, evaluate, then restore them in the same order."""
        saved = [(idx, model.simulate_module_failure(idx)) for idx in indices]
        acc = evaluate(model, test_loader, config)
        for idx, state in saved:
            model.restore_module(idx, state)
        return acc

    # One module at a time.
    for module_idx in range(model.num_modules):
        acc = accuracy_without([module_idx])
        drop = baseline - acc
        results['degradation'].append({
            'disabled_module': module_idx,
            'accuracy': acc,
            'drop': drop,
        })
        print(f"Module {module_idx} disabled: accuracy = {acc:.4f} (drop: {drop:.4f})")

    # The first k modules together; never disable every module.
    for k in (2, 3):
        if k >= model.num_modules:
            continue
        disabled = list(range(k))
        acc = accuracy_without(disabled)
        results['degradation'].append({
            'disabled_modules': disabled,
            'accuracy': acc,
            'drop': baseline - acc,
        })
        print(f"Modules 0-{k - 1} disabled: accuracy = {acc:.4f}")

    return results
def experiment_A7_feedback(dataset="cifar10"):
    """A7: Ablate the feedback pathway.

    Runs two arms, 3 seeds each: ``get_baseline_config()`` (with feedback)
    and ``get_ablation_no_feedback()`` (without). Only ``dataset`` and
    ``device`` are overridden on the factory configs.

    Returns:
        list of two ``run_experiment`` result dicts: with feedback, then without.
    """
    arms = (
        ("A7_with_feedback", get_baseline_config),
        ("A7_without_feedback", get_ablation_no_feedback),
    )
    outcomes = []
    for label, make_config in arms:
        cfg = make_config()
        cfg.dataset = dataset
        cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
        outcomes.append(run_experiment(label, cfg, num_runs=3))
    return outcomes
def experiment_A8_nerve_ring(dataset="cifar10"):
    """A8: Ablate the nerve ring.

    Runs two arms, 3 seeds each: ``get_baseline_config()`` (with the nerve
    ring) and ``get_ablation_no_nerve_ring()`` (without). Only ``dataset``
    and ``device`` are overridden on the factory configs.

    Returns:
        list of two ``run_experiment`` result dicts: with nerve ring, then without.
    """
    arms = (
        ("A8_with_nerve_ring", get_baseline_config),
        ("A8_without_nerve_ring", get_ablation_no_nerve_ring),
    )
    outcomes = []
    for label, make_config in arms:
        cfg = make_config()
        cfg.dataset = dataset
        cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
        outcomes.append(run_experiment(label, cfg, num_runs=3))
    return outcomes
def experiment_A9_heterogeneous(dataset="cifar10"):
    """A9: Compare homogeneous vs heterogeneous modules.

    Two arms, 3 seeds each, both trained for 50 epochs:
      * ``A9_homogeneous``: ``get_baseline_config()`` with ``homogeneous=True``.
      * ``A9_heterogeneous``: ``OctopusNetConfig(homogeneous=False)`` (mixed
        module types, described as CNN + Transformer + LSTM).

    Args:
        dataset: Dataset name applied to both arms.

    Returns:
        list of two ``run_experiment`` result dicts: homogeneous, then heterogeneous.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    epochs = 50
    results = []

    # Homogeneous arm, derived from the baseline config.
    config_homo = get_baseline_config()
    config_homo.dataset = dataset
    config_homo.homogeneous = True
    config_homo.device = device
    # Pin the epoch budget to match the heterogeneous arm; otherwise this arm
    # trains for whatever get_baseline_config() defaults to and the
    # comparison is confounded by training length.
    config_homo.epochs = epochs
    results.append(run_experiment("A9_homogeneous", config_homo, num_runs=3))

    # NOTE(review): this arm starts from OctopusNetConfig defaults rather than
    # get_baseline_config(); confirm the two agree on every other field so
    # that `homogeneous` is the only difference between the arms.
    config_hetero = OctopusNetConfig(
        dataset=dataset,
        homogeneous=False,
        epochs=epochs,
        device=device,
    )
    results.append(run_experiment("A9_heterogeneous", config_hetero, num_runs=3))
    return results
def experiment_A11_adaptive_threshold(dataset="cifar10"):
    """
    A11: Fixed vs adaptive Forward-Forward goodness threshold.

    Both arms use ``ff_threshold=2.0``, ``epochs=50`` and 3 seeds; they differ
    only in ``ff_adaptive_threshold``. With adaptation enabled, 2.0 is just
    the starting value. The adaptation itself is implemented outside this
    module; it is described as tracking the midpoint between positive and
    negative goodness each batch.

    Returns:
        list of two ``run_experiment`` result dicts: fixed (A11a), then adaptive (A11b).
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    variants = (
        ("A11a_fixed_threshold", False),
        ("A11b_adaptive_threshold", True),
    )
    collected = []
    for label, adaptive in variants:
        cfg = OctopusNetConfig(
            dataset=dataset,
            ff_threshold=2.0,
            ff_adaptive_threshold=adaptive,
            epochs=50,
            device=target_device,
        )
        collected.append(run_experiment(label, cfg, num_runs=3))
    return collected
def experiment_A10_competition_mechanism(dataset="cifar10"):
    """
    A10: Compare coordinator competition mechanisms (GWT-inspired).

    Arms, 3 seeds each, run in this order:
      * A10a: soft attention - baseline config with ``competition_type="soft"``.
      * A10b: Gumbel-softmax hard selection - ``get_gumbel_competition_config()``.
      * A10c: top-2 sparse attention - ``get_topk_competition_config(k=2)``.
      * A10d: top-1 (winner-take-all) - ``get_topk_competition_config(k=1)``.
    Only ``dataset`` and ``device`` (and, for A10a, ``competition_type``)
    are overridden on the factory configs.

    Returns:
        list of four ``run_experiment`` result dicts in the order above.
    """
    # (label, config factory, competition_type override or None)
    arms = (
        ("A10a_soft_attention", get_baseline_config, "soft"),
        ("A10b_gumbel_hard", get_gumbel_competition_config, None),
        ("A10c_topk_2", lambda: get_topk_competition_config(k=2), None),
        ("A10d_topk_1_winner_take_all", lambda: get_topk_competition_config(k=1), None),
    )
    outcomes = []
    for label, factory, competition in arms:
        cfg = factory()
        cfg.dataset = dataset
        if competition is not None:
            cfg.competition_type = competition
        cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
        outcomes.append(run_experiment(label, cfg, num_runs=3))
    return outcomes
def experiment_A12_multiscale(dataset="cifar10"):
    """
    A12: Multi-scale input hierarchy (V1/V2/V4/IT cortex analog).

    Three arms, 3 seeds each, all trained for the same 50 epochs:
      * A12_baseline: ``get_baseline_config()`` (reference).
      * A12a: adaptive kernels only, no resolution change (control).
      * A12b: adaptive kernels + multi-scale input (full bio-inspired hierarchy).

    Args:
        dataset: Dataset name; set on the baseline and passed to the A12a/A12b
            config factories.

    Returns:
        list of three ``run_experiment`` result dicts in the order above.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    epochs = 50
    results = []

    # Baseline: original homogeneous config.
    config_base = get_baseline_config()
    config_base.dataset = dataset
    config_base.device = device
    # Pin the baseline's epoch budget to match A12a/A12b; otherwise it trains
    # for get_baseline_config()'s default and the comparison is unequal.
    config_base.epochs = epochs
    results.append(run_experiment("A12_baseline", config_base, num_runs=3))

    # A12a: kernels adapted per resolution, but the same input size.
    config_a = get_adaptive_kernels_only_config(dataset)
    config_a.device = device
    config_a.epochs = epochs
    results.append(run_experiment("A12a_adaptive_kernels_only", config_a, num_runs=3))

    # A12b: full multi-scale (different input resolution per module).
    config_b = get_multiscale_config(dataset)
    config_b.device = device
    config_b.epochs = epochs
    results.append(run_experiment("A12b_multiscale", config_b, num_runs=3))
    return results
def experiment_A15_channel_grouping(dataset="cifar10"):
    """
    A15: Channel-grouping Forward-Forward vs standard Forward-Forward.

    The channel-grouping variant (``get_channel_grouping_config``) is
    described as a single forward pass in which channels are split into J
    groups, one per class, with goodness computed per group, so no negative
    samples are generated. Based on: Ortiz Torres et al. (arXiv:2504.21662,
    2025).

    Both arms run for 50 epochs over 3 seeds.

    Returns:
        list of two ``run_experiment`` result dicts: standard FF, then channel grouping.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"

    def standard_ff():
        # The baseline factory takes no dataset argument, so set it here.
        cfg = get_baseline_config()
        cfg.dataset = dataset
        return cfg

    arms = (
        ("A15_standard_ff", standard_ff),
        ("A15_channel_grouping", lambda: get_channel_grouping_config(dataset)),
    )
    outcomes = []
    for label, build in arms:
        cfg = build()
        cfg.device = target_device
        cfg.epochs = 50
        outcomes.append(run_experiment(label, cfg, num_runs=3))
    return outcomes
def run_all_experiments(datasets=("mnist", "cifar10")):
    """
    Run every experiment in this module for each dataset and save the results.

    Results go to ``all_experiments_<YYYYmmdd_HHMMSS>.json`` in the current
    working directory, timestamped when the sweep starts. The file is
    rewritten after every finished experiment, so a crash late in a long
    sweep keeps everything computed so far. Values ``json`` cannot encode
    natively are stringified via ``default=str``.

    Args:
        datasets: Iterable of dataset names. The default is a tuple so that
            no mutable object is shared between calls.

    Returns:
        dict mapping dataset name -> {experiment key -> experiment results}.
    """
    import os

    experiments = (
        ('A1_num_modules', experiment_A1_num_modules),
        ('A2_bottleneck', experiment_A2_bottleneck_size),
        ('A6_resilience', experiment_A6_resilience),
        ('A7_feedback', experiment_A7_feedback),
        ('A8_nerve_ring', experiment_A8_nerve_ring),
        ('A9_heterogeneous', experiment_A9_heterogeneous),
        ('A10_competition', experiment_A10_competition_mechanism),
        ('A11_adaptive_threshold', experiment_A11_adaptive_threshold),
        ('A12_multiscale', experiment_A12_multiscale),
        ('A15_channel_grouping', experiment_A15_channel_grouping),
    )

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = f'all_experiments_{timestamp}.json'

    def save(results):
        # Write to a temporary file, then swap it in, so an interruption
        # during the dump never leaves a truncated checkpoint behind.
        tmp_path = out_path + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, default=str)
        os.replace(tmp_path, out_path)

    all_results = {}
    for dataset in datasets:
        print(f"\n{'*'*70}")
        print(f"* Dataset: {dataset}")
        print(f"{'*'*70}")
        dataset_results = all_results[dataset] = {}
        for key, experiment in experiments:
            dataset_results[key] = experiment(dataset)
            save(all_results)

    # Final write; also produces the (empty) file when no datasets are given.
    save(all_results)
    return all_results
if __name__ == "__main__":
    # Script entry point: runs the entire experiment suite, restricted to the
    # MNIST dataset (every experiment, not a quick smoke test).
    mnist_results = run_all_experiments(datasets=["mnist"])
    print("\nAll experiments completed!")