-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuild_task_model.py
More file actions
81 lines (67 loc) · 2.79 KB
/
build_task_model.py
File metadata and controls
81 lines (67 loc) · 2.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import argparse
import os
import torch
import torch.nn as nn
from timm.models.vision_transformer import VisionTransformer, Block, PatchEmbed
from functools import partial
def main(args):
    """Attach a task-specific head to a pruned ViT encoder and save the model.

    Loads a pruned Vision Transformer checkpoint, transplants its transformer
    blocks and final norm into a freshly constructed ViT whose input channels
    match the requested task, builds an MLP task head of ``args.head_size``
    linear layers, and serializes the assembled model to ``args.save_dir``.

    Raises:
        ValueError: if ``args.task`` is not one of the supported tasks.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    pruned_vit = torch.load(args.pruned_vit, weights_only=False)

    # Extract the encoder components to reuse from the pruned model.
    blocks = nn.Sequential(*list(pruned_vit.blocks))  # ModuleList -> Sequential
    norm = pruned_vit.norm                            # final normalization layer

    # Per-task (input_channels, head output_dim) configuration.
    task_config = {
        'sensing': (3, 6),   # Human Activity Sensing
        'radio':   (1, 10),  # Radio Signal Identification
        '5g':      (4, 3),   # 5G Positioning
    }
    if args.task not in task_config:
        # Bug fix: the original if/elif chain had no else branch, so an
        # unknown task left input_channels/output_dim undefined (NameError).
        raise ValueError(
            f"Unknown task '{args.task}'; expected one of {sorted(task_config)}")
    input_channels, output_dim = task_config[args.task]

    # Fresh ViT shell matching the pruned encoder's geometry; its patch
    # embedding is kept (not replaced) so in_chans matches the task input.
    vit_model = VisionTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=512,
        depth=12,
        num_heads=8,
        mlp_ratio=4,
        in_chans=input_channels,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        num_classes=512,  # placeholder; the head below replaces the classifier
    )

    # Transplant the pruned encoder components.
    vit_model.blocks = blocks
    vit_model.norm = norm
    for module in vit_model.modules():
        if isinstance(module, nn.LayerNorm):
            module.eps = 1e-06

    # Build the dynamic task head: (head_size - 1) hidden layers followed by
    # the output projection.
    # Bug fixes vs. original:
    #  * hidden layers after the first now take 1024-dim input (the original
    #    hard-coded nn.Linear(512, 1024) for every hidden layer, which breaks
    #    with a shape mismatch for head_size > 2);
    #  * embed_dim is initialized before the loop (the original raised
    #    NameError for head_size == 1, where the loop body never ran).
    embed_dim = 512
    head_layers = []
    for _ in range(args.head_size - 1):
        head_layers += [nn.Linear(embed_dim, 1024), nn.GELU(), nn.Dropout(0.1)]
        embed_dim = 1024  # subsequent layers consume the hidden width
    head_layers.append(nn.Linear(embed_dim, output_dim))
    vit_model.head = nn.Sequential(*head_layers)
    print(f"Model built for task: {args.task} with head size: {args.head_size}")

    # Save the assembled task model.
    os.makedirs(args.save_dir, exist_ok=True)
    model_path = os.path.join(args.save_dir, f'pruned_ViT_for_{args.task}_task.pth')
    torch.save(vit_model, model_path)
    print(f"\nModel saved to: {model_path}")
if __name__ == "__main__":
    # CLI entry point: parse arguments and build the task model.
    parser = argparse.ArgumentParser(
        # Bug fix: the original description ("Prune the ViT encoder to a
        # specific pruning ratio.") described a different script.
        description="Attach a task-specific head to a pruned ViT encoder and save the model.")
    parser.add_argument('--pruned_vit', type=str, default='.',
                        help='Path to the pruned ViT')
    parser.add_argument('--task', type=str, default='sensing',
                        # choices rejects unsupported tasks at parse time with
                        # a clear message instead of failing inside main().
                        choices=['sensing', 'radio', '5g'],
                        help='The task that the model will perform. The tasks are: '
                             'Human Activity Sensing, Radio Signal Identification, '
                             '5G Positioning')
    parser.add_argument('--head_size', type=int, default=2,
                        # Bug fix: "linyar" typo in the original help text.
                        help='The size of the linear-layer task head.')
    parser.add_argument('--save_dir', type=str, default='.',
                        help='Directory to save the new model.')
    args = parser.parse_args()
    main(args)