-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathtrain.py
More file actions
executable file
·46 lines (36 loc) · 1.14 KB
/
train.py
File metadata and controls
executable file
·46 lines (36 loc) · 1.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""Training entrypoint that loads gin configs and launches the trainer."""
import argparse
from accelerate.logging import get_logger
import gin
from trainers import TrainerConfig
logger = get_logger(__name__, log_level="INFO")


def main(argv=None):
    """Parse gin configs/bindings, build the trainer, and start training.

    Gin config files (``--ginc``, repeatable) are parsed first and inline
    bindings (``--ginb``, repeatable) afterwards, so bindings take precedence
    over values set in the files. After the trainer is built, the operative
    gin config is logged and written to ``<exp_path>/config.gin`` so the run
    can be reproduced.

    Args:
        argv: Optional list of command-line arguments, excluding the program
            name. ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            which matches running this file as a script. Passing a list allows
            calling ``main`` programmatically.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ginc",
        action="append",
        help="gin config file",
    )
    parser.add_argument(
        "--ginb",
        action="append",
        help="gin bindings",
    )
    args = parser.parse_args(argv)

    # With action="append", an attribute stays None when its flag is never
    # given; normalize both to lists so gin always receives sequences.
    config_files = args.ginc or []
    bindings = args.ginb or []  # inline bindings: highest priority

    # Parse config files, then bindings (later values override earlier ones),
    # and finalize the gin config.
    gin.parse_config_files_and_bindings(
        config_files, bindings, finalize_config=True
    )

    trainer_cfg = TrainerConfig()
    trainer = trainer_cfg.build()

    # Log the operative config and persist it into the experiment folder.
    conf = gin.operative_config_str()
    logger.info(conf)

    exp_path = trainer_cfg.exp_path
    # Create the experiment folder in case build() did not already do so;
    # otherwise open() below would raise FileNotFoundError.
    exp_path.mkdir(parents=True, exist_ok=True)
    # NOTE(review): under a multi-process launch every rank presumably
    # reaches this write; consider restricting it to the main process.
    # Explicit UTF-8 so the saved config does not depend on the platform's
    # default locale encoding.
    with open(exp_path / "config.gin", "w", encoding="utf-8") as f:
        f.write(conf)

    trainer.train()


if __name__ == "__main__":
    main()