-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmain.py
More file actions
97 lines (80 loc) · 4.22 KB
/
main.py
File metadata and controls
97 lines (80 loc) · 4.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import argparse
import random
import torch
import numpy as np
from time import time
import logging
from torch.utils.data import DataLoader
from datasets import EmbDataset,AllEmbDataset
from models.clvae import CLVAE
from trainer import Trainer
def parse_args(argv=None):
    """Parse command-line arguments for index (RQ-VAE) training.

    Args:
        argv: Optional list of argument strings for testing; defaults to
            ``sys.argv[1:]`` when ``None`` (backward compatible with the
            original zero-argument call).

    Returns:
        argparse.Namespace with all training / model hyperparameters.
    """
    def str2bool(value):
        # argparse's `type=bool` is broken: bool("False") is True because any
        # non-empty string is truthy. Parse the text explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")

    parser = argparse.ArgumentParser(description="Index")
    # Optimization hyperparameters.
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--epochs', type=int, default=5000, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=2048, help='batch size')
    parser.add_argument('--num_workers', type=int, default=4, )
    parser.add_argument('--eval_step', type=int, default=50, help='eval step')
    parser.add_argument('--learner', type=str, default="AdamW", help='optimizer')
    parser.add_argument('--lr_scheduler_type', type=str, default="constant", help='scheduler')
    parser.add_argument('--warmup_epochs', type=int, default=50, help='warmup epochs')
    parser.add_argument("--data_path", type=str,
                        default="../data/Games/Games.emb-llama-td.npy",
                        help="Input data path.")
    parser.add_argument("--weight_decay", type=float, default=0.0, help='l2 regularization weight')
    parser.add_argument("--dropout_prob", type=float, default=0.0, help="dropout ratio")
    # Boolean flags use str2bool so `--bn False` actually disables the option.
    parser.add_argument("--bn", type=str2bool, default=False, help="use bn or not")
    parser.add_argument("--loss_type", type=str, default="mse", help="loss_type")
    parser.add_argument("--kmeans_init", type=str2bool, default=True, help="use kmeans_init or not")
    parser.add_argument("--kmeans_iters", type=int, default=100, help="max kmeans iters")
    # Sinkhorn-regularized quantization settings (one epsilon per codebook level).
    parser.add_argument('--sk_epsilons', type=float, nargs='+', default=[0.0, 0.0, 0.0], help="sinkhorn epsilons")
    parser.add_argument("--sk_iters", type=int, default=50, help="max sinkhorn iters")
    parser.add_argument("--device", type=str, default="cuda:0", help="gpu or cpu")
    # Vector-quantizer / model architecture settings.
    parser.add_argument('--num_emb_list', type=int, nargs='+', default=[256,256,256], help='emb num of every vq')
    parser.add_argument('--e_dim', type=int, default=32, help='vq codebook embedding size')
    parser.add_argument('--quant_loss_weight', type=float, default=1.0, help='vq quantion loss weight')
    parser.add_argument("--beta", type=float, default=0.25, help="Beta for commitment loss")
    parser.add_argument('--layers', type=int, nargs='+', default=[2048,1024,512,256,128,64], help='hidden sizes of every layer')
    parser.add_argument('--save_limit', type=int, default=5)
    parser.add_argument("--ckpt_dir", type=str, default="", help="output directory for model")
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Pin every RNG source so training runs are reproducible.
    SEED = 2024
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    args = parse_args()
    print("=================================================")
    print(args)
    print("=================================================")
    logging.basicConfig(level=logging.DEBUG)

    # Load the item-embedding dataset; the model input width comes from it.
    dataset = EmbDataset(args.data_path)

    model = CLVAE(
        in_dim=dataset.dim,
        num_emb_list=args.num_emb_list,
        e_dim=args.e_dim,
        layers=args.layers,
        dropout_prob=args.dropout_prob,
        bn=args.bn,
        loss_type=args.loss_type,
        quant_loss_weight=args.quant_loss_weight,
        beta=args.beta,
        kmeans_init=args.kmeans_init,
        kmeans_iters=args.kmeans_iters,
        sk_epsilons=args.sk_epsilons,
        sk_iters=args.sk_iters,
    )

    loader = DataLoader(
        dataset,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
    )

    # Trainer drives the optimization loop; fit() reports the best metrics seen.
    trainer = Trainer(args, model, len(loader))
    best_loss, best_collision_rate = trainer.fit(loader)
    print("Best Loss", best_loss)
    print("Best Collision Rate", best_collision_rate)