-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathvisualize_assign.py
More file actions
94 lines (72 loc) · 2.68 KB
/
visualize_assign.py
File metadata and controls
94 lines (72 loc) · 2.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#!/usr/bin/env python3
# Copyright (c) Megvii, Inc. and its affiliates.
import os
import sys
import random
import time
import warnings
from loguru import logger
import torch
import torch.backends.cudnn as cudnn
sys.path.append(sys.path[0].split('/tools')[0])
from yolox.exp import Exp, get_exp
from yolox.core import Trainer
from yolox.utils import configure_module, configure_omp
from yolox.tools.train import make_parser
class AssignVisualizer(Trainer):
    """Trainer subclass whose iterations call ``model.visualize`` on each batch.

    Each iteration fetches and preprocesses one batch, hands it to
    ``self.model.visualize`` together with a per-batch output path prefix,
    and terminates the process once ``args.max_batch`` batches were handled.
    """

    def __init__(self, exp: Exp, args):
        """Set up the base trainer and create the ``vis`` output directory.

        Args:
            exp: Experiment description forwarded to ``Trainer``.
            args: Parsed command-line arguments; ``args.max_batch`` is read
                later by ``train_one_iter``.
        """
        super().__init__(exp, args)
        # Number of batches visualized so far; also used in output file names.
        self.batch_cnt = 0
        out_dir = os.path.join(self.file_name, "vis")
        self.vis_dir = out_dir
        os.makedirs(out_dir, exist_ok=True)

    def train_one_iter(self):
        """Visualize one batch, record timings, and exit when the limit is hit."""
        started_at = time.time()

        # Pull the next batch and bring it to the configured dtype.
        images, labels = self.prefetcher.next()
        images = images.to(self.data_type)
        labels = labels.to(self.data_type)
        labels.requires_grad = False
        images, labels = self.exp.preprocess(images, labels, self.input_size)
        data_ready_at = time.time()

        # Output prefix is unique per batch; presumably the model appends
        # its own suffixes when saving -- confirm in ``model.visualize``.
        path_prefix = os.path.join(self.vis_dir, f"assign_vis_{self.batch_cnt}_")
        with torch.cuda.amp.autocast(enabled=self.amp_training):
            self.model.visualize(images, labels, path_prefix)

        if self.use_model_ema:
            self.ema_model.update(self.model)

        finished_at = time.time()
        timings = {
            "iter_time": finished_at - started_at,
            "data_time": data_ready_at - started_at,
        }
        self.meter.update(**timings)

        self.batch_cnt += 1
        # Stop the whole process once enough batches have been visualized.
        if self.batch_cnt >= self.args.max_batch:
            sys.exit(0)

    def after_train(self):
        """Log that assignment visualization has finished."""
        logger.info("Finish visualize assignment, exit...")
def assign_vis_parser():
    """Return the parser from ``make_parser`` extended with ``--max-batch``.

    ``--max-batch`` is an integer (default 1) giving the maximum number of
    image batches to visualize.
    """
    arg_parser = make_parser()
    arg_parser.add_argument(
        "--max-batch",
        type=int,
        default=1,
        help="max batch of images to visualize",
    )
    return arg_parser
@logger.catch
def main(exp: Exp, args):
    """Seed the RNGs if requested, configure cudnn, and run the visualizer.

    Args:
        exp: Experiment object; ``exp.seed`` controls optional seeding.
        args: Parsed command-line arguments passed on to ``AssignVisualizer``.
    """
    seed = exp.seed
    if seed is not None:
        # Seed Python and torch RNGs and request deterministic cudnn kernels.
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed training. This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! You may see unexpected behavior "
            "when restarting from checkpoints."
        )

    # set environment variables for distributed training
    configure_omp()
    cudnn.benchmark = True

    AssignVisualizer(exp, args).train()
if __name__ == "__main__":
    # Script entry point: prepare the runtime, parse CLI options, build the
    # experiment, then launch the assignment visualizer.
    configure_module()

    parser = assign_vis_parser()
    args = parser.parse_args()

    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    # Fall back to the experiment's own name when none was given on the CLI.
    args.experiment_name = args.experiment_name or exp.exp_name

    main(exp, args)