-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
99 lines (85 loc) · 3.07 KB
/
train.py
File metadata and controls
99 lines (85 loc) · 3.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import lightning as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from torch import optim
from model import NeuralSuperSampling
import kornia as K
from loss import WeightedSSIMPerceptualLoss
from dataset import QRISPDataset
from torch.utils.data import DataLoader
class NeuralSuperSamplingPL(L.LightningModule):
    """LightningModule wrapper around the NeuralSuperSampling network.

    Trains the model with a weighted SSIM + perceptual loss, comparing the
    network output against the ground-truth frame converted to YCbCr
    (the network presumably predicts in YCbCr space — confirm against
    NeuralSuperSampling's output head).
    """

    def __init__(
        self,
        scale_factor,
        num_frames=5,
        weight_scale=10,
        lr=1e-4,
        perceptual_weight=0.1,
    ):
        """
        Args:
            scale_factor: upscaling factor passed to NeuralSuperSampling.
            num_frames: number of input frames per sample.
            weight_scale: weighting scale forwarded to the network.
            lr: Adam learning rate.
            perceptual_weight: weight of the perceptual term in the loss.
        """
        super().__init__()
        self.lr = lr
        self.nss = NeuralSuperSampling(scale_factor, num_frames, weight_scale)
        self.loss = WeightedSSIMPerceptualLoss(perceptual_weight)
        # Record constructor args so checkpoints can be reloaded without them.
        self.save_hyperparameters()

    def _compute_loss(self, batch):
        # Shared by train/val/test steps: run the network and compare against
        # the RGB target converted to YCbCr. Batch layout is
        # (color, motion, depth, target) as produced by QRISPDataset.
        color, motion, depth, y = batch
        y_hat = self.nss(color, motion, depth)
        return self.loss(K.color.rgb_to_ycbcr(y), y_hat)

    def training_step(self, batch, batch_idx):
        """One optimization step; logs epoch-averaged train_loss."""
        loss = self._compute_loss(batch)
        self.log(
            "train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """One validation step; logs epoch-averaged val_loss (early-stopping monitor)."""
        loss = self._compute_loss(batch)
        self.log(
            "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def test_step(self, batch, batch_idx):
        """One test step; logs test_loss with Lightning's default reduction."""
        loss = self._compute_loss(batch)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        """Plain Adam over all parameters at the configured learning rate."""
        return optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, color, motion, depth):
        """Inference entry point: delegate straight to the wrapped network."""
        return self.nss(color, motion, depth)
if __name__ == "__main__":
    SEQUENCE_LENGTH = 5
    BATCH_SIZE = 8
    EPOCHS = 10
    NUM_WORKERS = 11  # tuned for the training machine's core count — adjust per host

    def _make_loader(split, shuffle):
        # One place to build a split's DataLoader; the three splits differed
        # only in `split` and `shuffle`, so the duplication is factored here.
        data = QRISPDataset("data/", split=split, sequence_length=SEQUENCE_LENGTH)
        return DataLoader(
            data,
            batch_size=BATCH_SIZE,
            shuffle=shuffle,
            num_workers=NUM_WORKERS,
            persistent_workers=True,  # keep workers alive across epochs
        )

    train_loader = _make_loader("train", shuffle=True)
    val_loader = _make_loader("val", shuffle=False)
    test_loader = _make_loader("test", shuffle=False)

    model = NeuralSuperSamplingPL(
        scale_factor=2, num_frames=5, weight_scale=10, lr=1e-4, perceptual_weight=0.1
    )
    # Stop if val_loss fails to improve for 3 consecutive epochs.
    early_stop = EarlyStopping(monitor="val_loss", patience=3, mode="min")
    trainer = L.Trainer(max_epochs=EPOCHS, callbacks=[early_stop])
    trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    trainer.test(model=model, dataloaders=test_loader)