-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
82 lines (61 loc) · 2.57 KB
/
train.py
File metadata and controls
82 lines (61 loc) · 2.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# -*- coding: utf-8 -*-
"""Untitled10.ipynb
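Training script: fits one DeepJSCC model on CIFAR-10 per fixed training SNR
and saves each model's weights under checkpoints/.

Automatically generated by Colab. The original notebook is located at
    https://colab.research.google.com/drive/1jP0hfA5O4j1AounpfAvq69V5YAC7RT8Z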
"""
# train.py
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from models import DeepJSCC
from eval import eval_psnr # eval 함수 재사용
def train_one_epoch(model, loader, opt, device, snr_train=(0, 20)):
    """Run one optimisation pass over ``loader`` and return the mean MSE.

    Args:
        model: Module invoked as ``model(x, snr_db)``; its output is compared
            against the input ``x`` with ``F.mse_loss``.
        loader: Iterable yielding ``(x, _)`` pairs; the second item is ignored.
        opt: Optimizer exposing ``zero_grad()`` and ``step()``.
        device: Target handed to ``x.to(device)``.
        snr_train: Either a single number, passed to the model unchanged for
            every batch, or a tuple/list whose first two items bound a uniform
            draw that is repeated for each batch.

    Returns:
        The batch losses averaged per sample (each batch weighted by
        ``x.size(0)``) over the samples that were actually iterated, or
        ``0.0`` when the loader yields no batches.
    """
    model.train()

    sample_range = isinstance(snr_train, (tuple, list))
    # Convert the fixed value once instead of once per batch.
    fixed_snr = None if sample_range else float(snr_train)

    loss_sum = 0.0
    seen = 0
    for x, _ in loader:
        x = x.to(device)
        if sample_range:
            snr_db = float(torch.empty(1).uniform_(snr_train[0], snr_train[1]).item())
        else:
            snr_db = fixed_snr

        xhat = model(x, snr_db)
        loss = F.mse_loss(xhat, x)

        opt.zero_grad()
        loss.backward()
        opt.step()

        batch = x.size(0)
        loss_sum += loss.item() * batch
        seen += batch

    # Normalise by the samples seen, not len(loader.dataset): the two differ
    # when the loader skips data (drop_last=True, subset samplers), and the
    # loader may not expose a sized ``dataset`` at all.
    if seen == 0:
        return 0.0
    return loss_sum / seen
def main():
    """Train one DeepJSCC model per fixed SNR on CIFAR-10 and save each one.

    For every value in the SNR list a fresh model and Adam optimizer are
    created, trained for a fixed number of epochs with ``train_one_epoch``,
    scored after each epoch with ``eval_psnr`` on the CIFAR-10 test split at
    the same SNR, and finally written to ``checkpoints/``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("[device]", device, flush=True)

    # Experiment settings.
    train_snr_list = [1, 4, 7, 13, 19]
    epochs, latent_ch, lr = 5, 8, 1e-3

    tfm = transforms.Compose([transforms.ToTensor()])

    for folder in ("./data", "checkpoints"):
        os.makedirs(folder, exist_ok=True)

    cifar_kwargs = {"root": "./data", "download": True, "transform": tfm}
    trainset = datasets.CIFAR10(train=True, **cifar_kwargs)
    testset = datasets.CIFAR10(train=False, **cifar_kwargs)

    # On Windows, num_workers=0 is recommended.
    loader_kwargs = {"num_workers": 0, "pin_memory": device == "cuda"}
    train_loader = DataLoader(trainset, batch_size=128, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(testset, batch_size=256, shuffle=False, **loader_kwargs)

    for snr_tr in train_snr_list:
        print(f"\n===== Train @ SNR_train={snr_tr} dB =====", flush=True)

        # Each SNR gets its own freshly initialised model and optimizer.
        model = DeepJSCC(latent_ch=latent_ch).to(device)
        opt = torch.optim.Adam(model.parameters(), lr=lr)

        for epoch in range(1, epochs + 1):
            tr = train_one_epoch(model, train_loader, opt, device, snr_train=snr_tr)
            p = eval_psnr(model, test_loader, device, snr_db=snr_tr)
            print(f"Epoch {epoch:02d} | train_mse={tr:.6f} | PSNR@{snr_tr}dB={p:.2f}", flush=True)

        # One checkpoint per training SNR, written after the last epoch.
        ckpt_path = f"checkpoints/deepjscc_snrtrain_{snr_tr}dB.pth"
        payload = {"snr_train": snr_tr, "state_dict": model.state_dict()}
        torch.save(payload, ckpt_path)
        print("[saved]", ckpt_path, flush=True)
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()