-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__main__.py
More file actions
57 lines (42 loc) · 1.63 KB
/
__main__.py
File metadata and controls
57 lines (42 loc) · 1.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import logging
import sys
from logging import basicConfig, getLogger
from arg_parse.make_docs import make_docs
from src.args.args import parse_args
from src.args.parsers.enums import Mode, ModelAction, TensorHandler, Device, MiscAction
from src.config.config import Config
from src.log.log_handler.log_handler import LogHandler
from src.trainer.trainer import Trainer
from stopwatch import Stopwatch
basicConfig(level=logging.INFO)
log = getLogger(__name__)


def main(argv=None):
    """Parse command-line arguments and dispatch to the selected mode/action.

    Args:
        argv: Command-line arguments, excluding the program name. Defaults to
            ``sys.argv[1:]`` when omitted.

    Raises:
        ValueError: If CUDA is requested without the torch tensor handler, or
            if the mode / action combination is not recognised.
    """
    if argv is None:
        argv = sys.argv[1:]

    log.info("Starting...")
    args_, base_parser = parse_args(*argv)
    config = Config.from_args(args_)

    # Validation: CUDA is only allowed together with the TORCH tensor handler.
    if config.device == Device.CUDA and config._th != TensorHandler.TORCH:
        raise ValueError("Need torch for CUDA.")

    log_handler = LogHandler()
    stopwatch = Stopwatch()
    stopwatch.start()
    try:
        if args_.mode == Mode.MODEL:
            # Reject unknown actions before building a Trainer, so no trainer
            # is constructed for an invalid request (same as the original flow).
            if args_.action not in (ModelAction.TRAIN, ModelAction.INFER):
                raise ValueError(f"Unknown model action: {args_.mode}.{args_.action}")
            trainer = Trainer(config, log_handler)
            if args_.action == ModelAction.TRAIN:
                trainer.train()
            else:
                trainer.infer()
        elif args_.mode == Mode.MISC:
            if args_.action == MiscAction.MAKE_DOCS:
                make_docs(base_parser)
            else:
                raise ValueError(f"Unknown misc action: {args_.mode}.{args_.action}")
        else:
            raise ValueError(f"Unknown mode: {args_.mode}")
    except KeyboardInterrupt:
        # Ctrl-C is treated as a graceful stop; record it rather than hiding it
        # behind the normal "Done" message.
        log.warning("Interrupted by user.")
    log.info(f"Done after {stopwatch.stop():.3f}s.")


if __name__ == '__main__':
    main()