-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain_initial_model.py
More file actions
81 lines (67 loc) · 3.14 KB
/
train_initial_model.py
File metadata and controls
81 lines (67 loc) · 3.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import pandas as pd
import wandb
import os
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from torchvision import transforms
import torch
# load own code
import sys
sys.path.append('../')
from sleeplib.Resnet_15.model import ResNet
from sleeplib.datasets import BonoboDataset , ContinousToSnippetDataset
from sleeplib.montages import CDAC_combine_montage
from sleeplib.transforms import cut_and_jitter,channel_flip, extremes_remover
# this holds all the configuration parameters
from sleeplib.config import Config
import pickle
# define model name and path
model_path = 'your_path/Models/spikenet2'
# load config and show all default parameters
config = Config()
config.print_config()
combine_montage = CDAC_combine_montage()
# load dataset
df = pd.read_csv(config.PATH_LUT_BONOBO,sep=';') # ; -> ,
# add transformations
transform_train = transforms.Compose([cut_and_jitter(windowsize=config.WINDOWSIZE,max_offset=0.1,Fq=config.FQ),
channel_flip(p=0.5),
extremes_remover(signal_max = 2000, signal_min = 2)
])
transform_val = transforms.Compose([cut_and_jitter(windowsize=config.WINDOWSIZE,max_offset=0,Fq=config.FQ),
extremes_remover(signal_max = 2000, signal_min = 2)])#,CDAC_signal_flip(p=0)])
# init datasets
sub_df = df[df['total_votes_received']>2]
train_df = sub_df[sub_df['Mode']=='Train']
val_df = sub_df[sub_df['Mode']=='Val']
# set up dataloaders
Bonobo_train = BonoboDataset(train_df,
config.PATH_FILES_BONOBO,
transform=transform_train,
montage = combine_montage
)
train_dataloader = DataLoader(Bonobo_train, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=os.cpu_count())
Bonobo_val = BonoboDataset(val_df,
config.PATH_FILES_BONOBO,
transform=transform_val,
montage = combine_montage
)
val_dataloader = DataLoader(Bonobo_val, batch_size=config.BATCH_SIZE,shuffle=False,num_workers=os.cpu_count())
# build model
model = ResNet(lr=config.LR,
n_channels=config.N_CHANNELS,
Focal_loss = True # True means loss function will be Focal loss. Otherwise will be BCE loss
)
# create a logger
wandb.init(dir='logging')
wandb_logger = WandbLogger(project='super_awesome_project')
# create callbacks with early stopping and model checkpoint (saves the best model)
callbacks = [EarlyStopping(monitor='val_loss',patience=5),ModelCheckpoint(dirpath=model_path,filename='weights',monitor='val_loss')]
# create trainer, use fast dev run to test the code
trainer = pl.Trainer(devices=1, accelerator="gpu", min_epochs=30,max_epochs=100,logger=wandb_logger,callbacks=callbacks,fast_dev_run=False)
# train the model
trainer.fit(model,train_dataloader,val_dataloader)
wandb.finish()