-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_twitter.py
More file actions
90 lines (65 loc) · 2.24 KB
/
train_twitter.py
File metadata and controls
90 lines (65 loc) · 2.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Default Packages
import os
import os.path as path
import pandas as pd
# Tokenization
from transformers import BertTokenizer
# Internal Packages
from models import SuicideClassifier
from utils import *
from datasets import TwitterDataModule
def read_twitterSI(data_path, n_examples=None):
    """Load the Twitter suicidal-ideation CSV and return a random sample.

    Parameters
    ----------
    data_path : str
        Path to a CSV file readable by ``pandas.read_csv``.
    n_examples : int, optional
        Number of rows to sample without replacement. ``None`` (default)
        keeps every row. Requests larger than the dataset are clamped to
        its length instead of letting ``DataFrame.sample`` raise
        ``ValueError``.

    Returns
    -------
    pandas.DataFrame
        The sampled rows, re-indexed 0..n-1.
    """
    df = pd.read_csv(data_path)
    if n_examples is None:
        n_examples = len(df)
    # Clamp so an oversized request degrades to "use everything" rather than
    # crashing (pd.DataFrame.sample raises when n > len with replace=False).
    df = df.sample(min(n_examples, len(df)))
    df.reset_index(inplace=True, drop=True)
    return df
# Project root on the training machine (hard-coded Windows path; the
# __main__ guard below repeats the same literal).
RUNNING_DIR = r'C:\Code\NLP\ProfileLevel_SI_Classifier'
# Folder holding the raw CSV datasets, derived from the constant above —
# NOTE(review): not from the RUNNING_DIR parameter of main(); confirm intent.
datasets_dir = path.join(RUNNING_DIR, 'Datasets')
# Checkpoint-folder name of the pretrained model that main() fine-tunes.
MODEL_SAVE_NAME = 'flour-horse'
def main(hparams, RUNNING_DIR=os.path.dirname(path.realpath(__file__))):
    """Fine-tune a checkpointed SuicideClassifier on the Twitter SI dataset.

    Parameters
    ----------
    hparams : object
        Hyperparameter namespace; must expose N_EXAMPLES, BERT_MODEL,
        MAX_EXAMPLE_LEN, BATCH_SIZE, NUM_EPOCHS and LEARNING_RATE.
    RUNNING_DIR : str, optional
        Directory containing ``model_checkpoints``. Defaults to this
        script's directory (evaluated once, at definition time).

    NOTE(review): the dataset is read via the module-level ``datasets_dir``,
    which is built from a hard-coded path — passing a different RUNNING_DIR
    moves only the checkpoint lookup, not the data. Confirm this is intended.
    """
    # Load and (optionally) subsample the raw tweets. The misspelled
    # "Origional" matches the file name on disk — do not "fix" it.
    tweets = read_twitterSI(
        path.join(datasets_dir, 'Origional Suicidal Tweets.csv'),
        n_examples=hparams.N_EXAMPLES,
    )

    # The tokenizer must match the BERT variant the checkpoint expects.
    bert_tokenizer = BertTokenizer.from_pretrained(hparams.BERT_MODEL)
    data_module = TwitterDataModule(
        tweets,
        bert_tokenizer,
        splits=[0.7, 0.3],
        max_example_len=hparams.MAX_EXAMPLE_LEN,
        shuffle=True,
        batch_size=hparams.BATCH_SIZE,
    )

    # Total optimizer steps over all epochs; floor division drops the final
    # partial batch.
    total_steps = len(tweets) // hparams.BATCH_SIZE * hparams.NUM_EPOCHS

    # Resume from the pretrained model's best checkpoint, re-arming the
    # scheduler/optimizer settings for this fine-tuning run.
    classifier = SuicideClassifier.load_from_checkpoint(
        checkpoint_path=path.join(
            RUNNING_DIR, 'model_checkpoints',
            f'{MODEL_SAVE_NAME}', 'best-checkpoint.ckpt'),
        training_steps=total_steps,
        # 20% warmup; value is a float — presumably accepted downstream,
        # TODO confirm.
        warmup_steps=total_steps / 5,
        lr=hparams.LEARNING_RATE,
        metrics=['ROC', 'binary_report'],
    )

    trainer = generate_trainer(
        generate_trainer_params(
            "TwitterSI Classification",
            hparams, RUNNING_DIR,
        )
    )
    trainer.fit(classifier, data_module)
    # twitter_trainer.test()  # left disabled in the original script
def wandb_sweep():
    """Entry point for a Weights & Biases hyperparameter sweep.

    Initializes a wandb run — the sweep controller merges its sampled
    values over the defaults given here — then hands the resulting config
    to ``main`` as the hyperparameter namespace.
    """
    import wandb  # local import: only needed when sweeping

    wandb.init(config={'CLASSES': ["suicidal-tweet"]})
    main(wandb.config)
if __name__ == '__main__':
    # Hard-coded project root for a direct (non-sweep) run; mirrors the
    # module-level RUNNING_DIR literal defined above.
    RUNNING_DIR = r'C:\Code\NLP\ProfileLevel_SI_Classifier'
    # Hyperparameters presumably comes from the `utils` star import —
    # verify; load the run config from the project root.
    hparams = Hyperparameters.from_file(
        path.join(RUNNING_DIR, 'twitter_hparams.json'))
    main(hparams, RUNNING_DIR=RUNNING_DIR)