-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathautoencoder_experiments.py
More file actions
62 lines (46 loc) · 1.88 KB
/
autoencoder_experiments.py
File metadata and controls
62 lines (46 loc) · 1.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import argparse
import copy
import os
import sys
import torch
import torch.nn as nn
from autoencoder_pruned_finetune_sig_identification import get_args_parser, main
from pruned_eval_radio_sig_identification import evaluate_model
# Decoder-depth sweep: for k = 1..num_blocks, keep only the first k decoder
# blocks of the pruned autoencoder, swap the decoder head for a classification
# head, save that model, fine-tune it with `main`, then evaluate the best
# checkpoint written by the fine-tuning run.

# Path to the pruned checkpoint (a fully pickled nn.Module).
PRUNED_MODEL_PATH = 'CFFM/pruning_models_small_ViT/pruned_autoencoder/pruned_model_with_ratio_79.29%.pth'
# Root directory for all artifacts of this sweep.
EXPERIMENT_PATH = 'autoencoder_experiments/radio_sig_id'
# Input width of the new classification head.
# NOTE(review): assumes the decoder blocks emit 256-dim features -- confirm.
HEAD_IN_FEATURES = 256
# Output width of the classification head.
# Presumably the number of radio-signal classes -- confirm against the dataset.
NUM_CLASSES = 20

# Load the full pruned autoencoder.
# NOTE(review): weights_only=False unpickles arbitrary Python objects; only
# load checkpoints from a trusted source.
pruned_autoencoder = torch.load(PRUNED_MODEL_PATH, weights_only=False)

# Create the experiment directory, including any missing parents, once up front.
# A plain os.mkdir fails when 'autoencoder_experiments' does not exist yet.
os.makedirs(EXPERIMENT_PATH, exist_ok=True)

num_blocks = len(pruned_autoencoder.decoder_blocks)
for i in range(num_blocks):
    n_keep = i + 1

    # Deep copy so truncation and the new head never touch the original model.
    model_copy = copy.deepcopy(pruned_autoencoder)
    # Slice the *copy's* blocks. Slicing pruned_autoencoder here would make
    # model_copy share block modules (and their parameters) with the original.
    model_copy.decoder_blocks = nn.ModuleList(list(model_copy.decoder_blocks[:n_keep]))
    # model_copy.patch_embed.proj = nn.Conv2d(4, 512, kernel_size=(16, 16), stride=(16, 16))
    model_copy.decoder_pred = nn.Sequential(
        nn.Dropout(0.1),
        nn.Linear(HEAD_IN_FEATURES, NUM_CLASSES),
    )

    # Save the truncated model; fine-tuning receives it via args.model_path.
    model_path = os.path.join(EXPERIMENT_PATH, f'pruned_autoencoder_with_{n_keep}_blocks.pth')
    torch.save(model_copy, model_path)
    # The in-memory copy is no longer used after saving; release it before
    # training to keep peak memory down.
    del model_copy

    parser = get_args_parser()
    args = parser.parse_args([])  # start from all default values
    args.lr = 1e-3
    args.epochs = 200
    args.model_path = model_path
    args.output_dir = os.path.join(EXPERIMENT_PATH, 'finetune_logs', f'block_{n_keep}')
    args.log_dir = args.output_dir
    main(args)

    # Assumes `main` writes 'best_model.pth' into args.output_dir -- confirm.
    best_model_path = os.path.join(args.output_dir, 'best_model.pth')
    save_path = os.path.join(args.output_dir, f'performance_block_{n_keep}.png')
    train_performance, test_performance = evaluate_model(best_model_path, save_path)
    print(f"Decoder blocks kept: {n_keep}/{num_blocks}")
    print(f"Train Performance: {train_performance:.2%}, Test Performance: {test_performance:.2%}")