# visualise.py
# Source: repository forked from WeiChengTseng/Pytorch-PCGrad.
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from dataset import MultiMNIST
def main():
    """Plot the first shuffled batch of the MultiMNIST test split.

    Builds ``MultiMNIST`` from ``./data`` (``train=False``, ``download=False``,
    ``multi=True``) with a ``ToTensor`` + ``Normalize((0.1307,), (0.3081,))``
    transform, draws a single shuffled batch, and shows it on a 2 x 5 grid of
    grayscale panels. Each panel is titled ``"<a>_<b>"`` where ``a`` and ``b``
    are the matching entries of the batch's second and third elements
    (presumably the left/right digit labels -- confirm against ``dataset.py``).

    The figure is displayed with ``plt.show()`` and then closed. If the loader
    yields no batch at all, the function returns without plotting.
    """
    # Grid geometry. The batch size is derived from it so the loader, the
    # reshape and the plotting loop cannot drift out of sync.
    n_rows, n_cols = 2, 5
    # Per-image size used for the reshape (assumes single-channel 28 x 28
    # images, as the original view(10, 28, 28) did).
    height, width = 28, 28

    global_transformer = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    dst = MultiMNIST(
        root="./data",
        train=False,
        download=False,
        transform=global_transformer,
        multi=True,
    )
    loader = torch.utils.data.DataLoader(
        dst, batch_size=n_rows * n_cols, shuffle=True, num_workers=4
    )

    # Only the first batch is visualised. The None default keeps the original
    # "do nothing" behaviour when the loader is empty.
    dat = next(iter(loader), None)
    if dat is None:
        return

    # -1 infers the batch dimension, so a batch shorter than the grid is
    # reshaped instead of raising.
    ims = dat[0].view(-1, height, width).numpy()
    labs_l = dat[1]
    labs_r = dat[2]

    # squeeze=False guarantees a 2-D array of Axes whatever the grid shape.
    _, axarr = plt.subplots(n_rows, n_cols, squeeze=False)
    # Column-major walk: image k lands on row k % n_rows, column k // n_rows,
    # the same placement as the original axarr[i][j] <- ims[j * 2 + i].
    # zip stops at the shortest input, leaving surplus panels empty.
    for ax, im, lab_l, lab_r in zip(axarr.T.flat, ims, labs_l, labs_r):
        ax.imshow(im, cmap="gray")
        ax.set_title("{}_{}".format(lab_l, lab_r))
    plt.show()
    plt.close()
# Script entry point: run the visualisation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()