-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathbaseline_train.py
More file actions
71 lines (55 loc) · 2.7 KB
/
baseline_train.py
File metadata and controls
71 lines (55 loc) · 2.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import comet_ml
import torch
import pytorch_lightning as pl
import randomname
import yaml
import argparse
import os
import platform
from pprint import pprint
from time import strftime
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import CometLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.tuner import Tuner
from dataloader import ProteinDataset
from baseline_model import CNNLSTMPredictor
from utils import check_offline
torch.set_float32_matmul_precision('medium')
def train(config):
    """Train, validate, and test a CNNLSTMPredictor on protein data.

    Args:
        config: Either a path to a YAML config file (str) or an
            already-loaded config dict. Must contain at least
            'data_config' and 'train_config' (with 'early_stopping'
            and 'max_epochs' keys) plus the kwargs expected by
            CNNLSTMPredictor.

    Returns:
        The fitted ``pl.Trainer`` instance (after ``.test()`` has run).

    Raises:
        ValueError: If ``config`` is neither a str path nor a dict.

    Side effects:
        Creates ``logs/<experiment_id>`` and
        ``saved_models_baseline/<experiment_id>`` directories and writes
        checkpoints into the latter.
    """
    # Unique run identifier: timestamp + random human-readable name.
    start_time = strftime('%Y%m%d-%H%M%S')
    experiment_id = f'{start_time}_' + randomname.get_name()
    seed_everything(42)

    if isinstance(config, str):
        print(f'Loading config from {config}')
        # Use a context manager so the file handle is closed promptly
        # (the original open() without close() leaked the handle).
        with open(config, 'r') as f:
            config = yaml.safe_load(f)
        pprint(config)
    elif isinstance(config, dict):
        print('Using config dict')
        pprint(config)
    else:
        raise ValueError('config must be either a path to a yaml file or a dict')

    # Build model first so its tokenizer can be shared with every split.
    model = CNNLSTMPredictor(**config)
    train_dataset = ProteinDataset(config=config['data_config'], split='train', tokenizer=model.tokenizer)
    model.set_train_dataset(train_dataset)
    val_dataset = ProteinDataset(config=config['data_config'], split='val', tokenizer=model.tokenizer)
    model.set_val_dataset(val_dataset)
    test_dataset = ProteinDataset(config=config['data_config'], split='test', tokenizer=model.tokenizer)
    model.set_test_dataset(test_dataset)

    os.makedirs(f'logs/{experiment_id}', exist_ok=True)
    os.makedirs(f'saved_models_baseline/{experiment_id}', exist_ok=True)

    # Checkpoint the single best model by validation AUC; stop early when
    # val_auc plateaus for 'early_stopping' epochs.
    callbacks = []
    callbacks.append(ModelCheckpoint(dirpath=f'saved_models_baseline/{experiment_id}', monitor='val_auc', mode='max',
                                     filename='auc_{epoch}', save_last=False, save_top_k=1, verbose=False))
    callbacks.append(EarlyStopping(monitor='val_auc', patience=config['train_config']['early_stopping'],
                                   mode='max', verbose=True))

    # Hoist the repeated accelerator decision (was computed twice).
    accelerator = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Auto-scale batch size by powers of two; cap trials lower on macOS
    # (Darwin), presumably to avoid exhausting limited local memory —
    # TODO confirm intent with the original author.
    tuner = Tuner(pl.Trainer(accelerator=accelerator, num_sanity_val_steps=0))
    tuner.scale_batch_size(model, mode='power', init_val=2, max_trials=1 if platform.system() == 'Darwin' else 9)
    # Halve the tuned batch size to leave memory headroom during training.
    model.batch_size //= 2

    trainer = pl.Trainer(accelerator=accelerator,
                         max_epochs=config['train_config']['max_epochs'], log_every_n_steps=1,
                         callbacks=callbacks, num_sanity_val_steps=0, max_time='01:23:50:00')
    trainer.fit(model)
    # Restore the best (highest val_auc) weights before the final test pass,
    # since fit() leaves the model at its last-epoch state.
    model.load_state_dict(torch.load(trainer.checkpoint_callback.best_model_path)['state_dict'])
    trainer.test(model)
    return trainer