train.py (64 lines · 2.5 KB)
import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from torch.utils.data import DataLoader

import callbacks
from dataset import TrajectoryDataset, load_data
from nn import KGainModel

if __name__ == '__main__':
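    # Experiment tracking: hyperparameters are stored in wandb.config so each run records its own configuration.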
    wandb_logger = WandbLogger(project="eskf-ship-2", log_model=True)
    wandb.config.device = "cuda"
    wandb.config.gradient_clip_val = 1
    wandb.config.apply_const = True
    wandb.config.train_timesteps = 10
    wandb.config.recur_hidden_dim = 32
    wandb.config.n_recur_layers = 2
    wandb.config.recur_dropout = 0
    wandb.config.fc_dim = 32
    wandb.config.output_dim = 9
    wandb.config.fc_dropout = 0
    wandb.config.lr = 5e-4
    wandb.config.att_coef = 10
    wandb.run.log_code('.', exclude_fn=lambda f: 'venv' in f)
    tensorboard_logger = TensorBoardLogger("lightning_logs")

    torch.manual_seed(0)
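
    # Load the simulated trajectories, split off a validation set, and wrap everything in DataLoaders.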
    data = load_data('simulations/train.pkl')
    test_data = load_data('simulations/test.pkl')
    train_data, val_data = data.split(0.8)
    train_dataset = TrajectoryDataset(train_data)
    val_dataset = TrajectoryDataset(val_data)
    test_dataset = TrajectoryDataset(test_data)
    train_dl = DataLoader(train_dataset, batch_size=1000, shuffle=True)
    val_dl = DataLoader(val_dataset, batch_size=1000, shuffle=False)
    test_dl = DataLoader(test_dataset, batch_size=1000, shuffle=False)
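
    # Instantiate the model from the logged hyperparameters and register the beacon positions from the data.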
    model = KGainModel(9, wandb.config.recur_hidden_dim, wandb.config.n_recur_layers, wandb.config.recur_dropout,
                       wandb.config.fc_dim, wandb.config.output_dim, wandb.config.fc_dropout, wandb.config.lr,
                       wandb.config.device, wandb.config.train_timesteps, wandb.config.apply_const,
                       wandb.config.att_coef)
    model.set_beacons(data.beacon_positions)
    wandb_logger.watch(model, log="all")

    # PyTorch Lightning's Trainer expects the accelerator name "gpu" (or "cpu"), not "cuda".
    dev = wandb.config["device"]
    dev = "gpu" if dev == "cuda" else "cpu"
    trainer = pl.Trainer(accelerator=dev, max_epochs=2500, logger=[wandb_logger, tensorboard_logger],
                         callbacks=[callbacks.LinearTimesteps(2, 1000, 100),
                                    ModelCheckpoint(monitor='val_loss', mode='min')],
                         gradient_clip_val=wandb.config["gradient_clip_val"],
                         log_every_n_steps=8)
    trainer.fit(model, train_dl, val_dl)

    # Evaluate the trained model on the held-out test set and print the resulting metrics.
    test_result = trainer.test(model, test_dl)
    print(test_result)