-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpruned_eval_sensing.py
More file actions
156 lines (127 loc) · 6.24 KB
/
pruned_eval_sensing.py
File metadata and controls
156 lines (127 loc) · 6.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import argparse
from genericpath import exists
from pathlib import Path
import os
import torch
from dataset_classes.csi_sensing import CSISensingDataset
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
import models_vit
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
import timm
import models_mae
def no_mask_forward(self, imgs, mask_ratio=0.0):
    """Run an MAE model end-to-end without masking any patches.

    Intended to be bound onto a MaskedAutoencoderViT instance at eval
    time (via ``__get__``) so the full token sequence passes through
    both encoder and decoder.

    Args:
        imgs: input batch, forwarded unchanged to ``forward_encoder``.
        mask_ratio: fraction of patches to mask; defaults to 0.0 (none).

    Returns:
        Tensor of shape [B, decoder_pred_dim]: the decoder output
        averaged over the token dimension, used as a classification
        embedding.
    """
    encoded, _, restore_ids = self.forward_encoder(imgs, mask_ratio)
    decoded = self.forward_decoder(encoded, restore_ids)
    # Pool across all tokens (a mean over the sequence dimension, not a
    # CLS-token pick, despite the upstream comment).
    return decoded.mean(dim=1)
def forward(self, x):
    """Multi-head self-attention forward that tolerates pruned heads.

    Drop-in replacement for timm's ``Attention.forward`` (see
    https://github.com/huggingface/pytorch-image-models/blob/054c763fcaa7d241564439ae05fbe919ed85e614/timm/models/vision_transformer.py#L79).
    The only substantive difference from upstream is the final reshape
    to ``-1`` instead of the input channel count ``C``, so the output
    width follows the (possibly reduced) number of heads.
    """
    batch, tokens, channels = x.shape
    qkv = (
        self.qkv(x)
        .reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        .permute(2, 0, 3, 1, 4)
    )
    query, key, value = qkv.unbind(0)
    query = self.q_norm(query)
    key = self.k_norm(key)

    if self.fused_attn:
        # Fused kernel path (flash / memory-efficient attention).
        out = F.scaled_dot_product_attention(
            query, key, value,
            dropout_p=self.attn_drop.p,
        )
    else:
        # Explicit softmax(QK^T / sqrt(d)) V path.
        scores = (query * self.scale) @ key.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))
        out = weights @ value

    # -1 rather than `channels`: head pruning can shrink the width.
    out = out.transpose(1, 2).reshape(batch, tokens, -1)
    return self.proj_drop(self.proj(out))
def main(args):
    """Evaluate a (possibly pruned) ViT/MAE checkpoint on a CSI sensing
    test set, then save and plot per-class confusion matrices.

    Side effects: writes ``args.output_conf_mats`` and
    ``args.output_accuracy`` (.npy, raw counts / fractions) plus
    ``args.output_plot`` (.png), and displays the figure.
    """
    device = args.device
    data_dir = Path(args.data_dir)
    # (timm key, checkpoint filename) pairs; kept for bookkeeping and
    # progress display — the actual weights come from args.model_dir.
    models = [("vit_small_patch16", "sensing_small_75.pth")]
    batch_size = args.batch_size
    num_workers = args.num_workers

    test_set = CSISensingDataset(data_dir)
    test_loader = DataLoader(test_set, batch_size=batch_size,
                             num_workers=num_workers, shuffle=False)

    num_classes = 6  # NTU-Fi HAR activity classes
    accuracies = np.zeros((len(models),))
    conf_matrices = np.zeros((len(models), num_classes, num_classes))

    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    model = torch.load(args.model_dir, weights_only=False)
    if model.__class__.__name__ == 'MaskedAutoencoderViT':
        # Bind the no-mask forward so the full token sequence is used.
        model.forward = no_mask_forward.__get__(model, models_mae.MaskedAutoencoderViT)
    for m in model.modules():
        if isinstance(m, timm.models.vision_transformer.Attention):
            # Patched attention tolerates pruned heads (reshape to -1).
            m.forward = forward.__get__(m, timm.models.vision_transformer.Attention)

    # Hoisted out of the loop: device placement and eval mode are
    # loop-invariant. model.eval() is a fix — without it, any dropout
    # layers stay active during evaluation and corrupt the metrics.
    model = model.to(device)
    model.eval()
    print(model)

    with torch.no_grad():
        for i, (model_key, model_name) in enumerate(tqdm(models, desc="Models")):
            all_targets = []
            all_preds = []
            for images, targets in tqdm(test_loader, desc="Batches", leave=False):
                images = images.to(device)
                targets = targets.to(device)
                pred = torch.argmax(model(images), dim=-1)
                all_targets.extend(targets.tolist())
                all_preds.extend(pred.tolist())
            all_targets = np.array(all_targets)
            all_preds = np.array(all_preds)
            # Fix: pass labels explicitly so the matrix is always
            # num_classes x num_classes even when a class is absent from
            # the predictions/targets (otherwise the assignment below
            # raises a shape mismatch).
            conf_matrices[i] = confusion_matrix(
                all_targets, all_preds, labels=np.arange(num_classes))
            accuracies[i] = np.sum(all_targets == all_preds) / len(all_targets)

    # Save the raw (un-normalized) confusion matrices and accuracies.
    np.save(args.output_conf_mats, conf_matrices)
    np.save(args.output_accuracy, accuracies)

    # Row-normalize each confusion matrix (per-true-label recall).
    for i in range(len(models)):
        row_sums = np.sum(conf_matrices[i], axis=1)
        row_sums[row_sums == 0] = 1  # guard against empty rows
        # Fix: broadcast over rows. Dividing by the 1-D row_sums would
        # broadcast over the LAST axis, i.e. normalize columns by the
        # wrong row's total.
        conf_matrices[i] = conf_matrices[i] / row_sums[:, np.newaxis].astype(float)

    class_labels = test_set.labels
    titles = ['ViT-S75']
    plt.rcParams['font.family'] = 'serif'
    # Single-cell gridspec kept so the layout generalizes to more models.
    fig = plt.figure(figsize=(14, 12))
    gs = fig.add_gridspec(1, 1, width_ratios=[1])
    axs = [fig.add_subplot(gs[i, j]) for i in range(1) for j in range(1)]
    for i, ax in enumerate(axs):
        # Colorbar only on the last subplot of each row (trivially every
        # subplot in this 1x1 layout).
        show_cbar = (i + 1) % 1 == 0
        sns.heatmap(conf_matrices[i], annot=True, fmt='.2f', cmap='Reds',
                    xticklabels=class_labels, yticklabels=class_labels, ax=ax,
                    annot_kws={'size': 10}, cbar=show_cbar)
        ax.set_title(titles[i], fontsize=16)
        ax.tick_params(axis='both', labelsize=10)
    axs[0].set_ylabel('True label', fontsize=16)
    axs[0].set_xlabel('Predicted label', fontsize=16)
    plt.tight_layout()
    plt.savefig(args.output_plot, dpi=400)
    plt.show()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Evaluate ViT models on CSI sensing dataset with varying mask ratios.")
parser.add_argument('--model_dir', type=str, default='checkpoints', help='Path of the pruned model')
parser.add_argument('--data_dir', type=str, default='/home/ict317-3/Mohammad/mae/fine-tuning_datasets/NTU-Fi_HAR/test',
help='Path to CSI Sensing dataset directory')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for DataLoader')
parser.add_argument('--num_workers', type=int, default=0, help='Number of workers for DataLoader')
parser.add_argument('--output_plot', type=str, default='/home/ict317-3/Mohammad/mae/has_output_dir/fig_conf_matrices_sensing.png',
help='Path to save the accuracy plot')
parser.add_argument('--output_accuracy', type=str, default='accuracies_sensing.npy',
help='Path to save the accuracy array as a .npy file')
parser.add_argument('--output_conf_mats', type=str, default='conf_mats_sensing.npy',
help='Path to save the accuracy array as a .npy file')
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_args()
main(args)