-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathrun.py
More file actions
72 lines (59 loc) · 2.11 KB
/
run.py
File metadata and controls
72 lines (59 loc) · 2.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# -*- coding: utf-8 -*-
# @Author : Hao Fan
# @Time : 2024/6/13
import sys
import torch
from recbole.config import Config
from logging import getLogger
from tim4rec import TiM4Rec
from recbole.data import create_dataset, data_preparation
from recbole.data.transform import construct_transform
from recbole.trainer import Trainer
from recbole.utils import (
init_logger,
init_seed,
set_color,
get_flops,
get_environment,
)
# Turn TensorFloat-32 off for both cuBLAS matmuls and cuDNN convolutions so
# float32 tensor math keeps full FP32 precision on GPUs that support TF32
# (presumably so results are reproducible across GPU generations).
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = False
def main(config_file_list=None):
    """Train TiM4Rec with RecBole, then evaluate it on the test split.

    Args:
        config_file_list: YAML config files handed to RecBole's ``Config``.
            Defaults to ``['config/config4beauty_64d.yaml']``, the file this
            script has always used.

    Returns:
        A dict with ``best_valid_score`` and ``best_valid_result`` (from
        ``Trainer.fit``) and ``test_result`` (from ``Trainer.evaluate``).
    """
    if config_file_list is None:
        config_file_list = ['config/config4beauty_64d.yaml']
    config = Config(model=TiM4Rec, config_file_list=config_file_list)
    # Offset the seed by local_rank so processes with different ranks get
    # different seeds.
    init_seed(config["seed"] + config["local_rank"], config["reproducibility"])

    # logger initialization
    init_logger(config)
    logger = getLogger()
    logger.info(sys.argv)
    logger.info(config)

    # dataset filtering
    dataset = create_dataset(config)
    logger.info(dataset)

    # dataset splitting
    train_data, valid_data, test_data = data_preparation(config, dataset)

    # model loading and initialization
    model = TiM4Rec(config, train_data.dataset)
    model = model.to(config['device'])
    logger.info(model)

    # Report the model's FLOPs (computed by RecBole's get_flops helper).
    transform = construct_transform(config)
    flops = get_flops(model, dataset, config["device"], logger, transform)
    logger.info(set_color("FLOPs", "blue") + f": {flops}")

    # trainer loading and initialization
    trainer = Trainer(config, model)
    # Optionally continue training from a checkpoint named in the config.
    if config['checkpoint_path'] is not None:
        trainer.resume_checkpoint(config['checkpoint_path'])

    # model training
    best_valid_score, best_valid_result = trainer.fit(
        train_data, valid_data, show_progress=config["show_progress"]
    )

    # model evaluation
    test_result = trainer.evaluate(
        test_data, show_progress=config["show_progress"]
    )

    environment_tb = get_environment(config)
    logger.info(
        "The running environment of this training is as follows:\n"
        + environment_tb.draw()
    )
    logger.info(set_color("best valid ", "yellow") + f": {best_valid_result}")
    logger.info(set_color("test result", "yellow") + f": {test_result}")

    return {
        "best_valid_score": best_valid_score,
        "best_valid_result": best_valid_result,
        "test_result": test_result,
    }


if __name__ == '__main__':
    main()