-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_classifier.py
More file actions
58 lines (48 loc) · 2.54 KB
/
train_classifier.py
File metadata and controls
58 lines (48 loc) · 2.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
Train a topic classifier.
Reads training embeddings and labels, trains a model, and saves the checkpoint containing the model weights and label mapping.
Usage:
python3 train_classifier.py --train_embeddings ../data/train_emb.npz \
--labels ../data/labels.txt \
--output ../models/model.pt \
"""
from argparse import ArgumentParser
from datetime import datetime
from pathlib import Path

import torch

from topic_model import train
def _read_label_map(path):
    """Read one label per line from *path* and map each distinct label to an index.

    Surrounding whitespace and blank lines are ignored. Repeated labels keep the
    index of their first occurrence, so the indices are always exactly
    ``0 .. len(label_map) - 1``. This keeps ``outputsize = len(label_map)``
    consistent with every index in the map.

    Args:
        path: Path to the labels text file.

    Returns:
        dict mapping label string -> integer index.

    Raises:
        ValueError: if the file contains no non-blank labels.
    """
    # utf-8-sig decodes plain UTF-8 and also strips a leading BOM, which
    # str.strip() would otherwise leave glued to the first label.
    with open(path, "r", encoding="utf-8-sig") as f:
        labels = [line.strip() for line in f if line.strip()]
    # dict.fromkeys drops duplicates while preserving first-seen order.
    unique_labels = list(dict.fromkeys(labels))
    if len(unique_labels) != len(labels):
        print(f"  Warning: ignored {len(labels) - len(unique_labels)} duplicate label line(s) in {path}", flush=True)
    if not unique_labels:
        raise ValueError(f"no labels found in {path}")
    return {label: idx for idx, label in enumerate(unique_labels)}


def main(argv=None):
    """Parse CLI arguments, train the topic classifier, and save a checkpoint.

    The checkpoint written to ``--output`` is a dict with keys
    ``model_state_dict``, ``label_map``, ``inputsize``, ``hiddensize`` and
    ``outputsize``.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behaviour.
    """
    argparser = ArgumentParser(description="Trains a feed-forward topic classifier on sentence embeddings.")
    argparser.add_argument("--train_embeddings", type=str, required=True, help="Path to training embeddings .npz file.")
    argparser.add_argument("--labels", type=str, required=True, help="Path to labels.txt file.")
    argparser.add_argument("--output", type=str, required=True, help="Where to save the trained model (.pt file).")
    argparser.add_argument("--epochs", type=int, default=50, help="Number of training epochs (default: 50).")
    argparser.add_argument("--batch_size", type=int, default=32, help="Batch size (default: 32).")
    argparser.add_argument("--lr", type=float, default=0.001, help="Learning rate for Adam (default: 0.001).")
    argparser.add_argument("--hiddensize", type=int, default=128, help="Hidden layer size (default: 128).")
    args = argparser.parse_args(argv)

    print("=" * 100)

    try:
        label_map = _read_label_map(args.labels)
    except ValueError as err:
        # Report as a normal CLI usage error (message + exit status 2)
        # instead of training a model with zero output units.
        argparser.error(str(err))

    # Create the checkpoint directory up front so a bad --output path fails
    # before the (potentially long) training run rather than after it.
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)

    print(f"[{datetime.now()}] Training feed-forward topic classifier...", flush=True)
    # Log the configuration before training starts; flush so it is visible
    # immediately even when stdout is piped to a file.
    print(f"  Parameters: epochs={args.epochs}, batch_size={args.batch_size}, adam_learn_rate={args.lr}, hidden_size={args.hiddensize}", flush=True)

    model = train(
        npz_path=args.train_embeddings,
        label_map=label_map,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        hiddensize=args.hiddensize,
        outputsize=len(label_map),
    )

    torch.save({
        'model_state_dict': model.state_dict(),
        'label_map': label_map,
        # Input width is taken from the trained model's first linear layer
        # (`linear0` is an attribute of the model returned by topic_model.train).
        'inputsize': model.linear0.in_features,
        'hiddensize': args.hiddensize,
        'outputsize': len(label_map),
    }, args.output)
    print(f"  Model saved to {args.output}")
    print(f"\n[{datetime.now()}] Training successful!!", flush=True)
    print("=" * 100)


if __name__ == "__main__":
    main()