-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy pathuseRay.py
More file actions
executable file
·82 lines (72 loc) · 3 KB
/
useRay.py
File metadata and controls
executable file
·82 lines (72 loc) · 3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/usr/bin/env python3
from ray import train, tune, shutdown, init
import argparse
# import sys
from os import environ
# from time import sleep
# from random import random
from rayConfig import myfunc, gen_params
# Number of concurrent Tune trials: passed as num_samples to TuneConfig
# for the PBT and PB2 schedulers (ASHA uses its own num_samples=50).
nb_threads = 6
# Ray temp/session directory: passed to ray.init(_temp_dir=...) and
# forwarded to the trainable via myfunc(config, args, temp_dir).
temp_dir = '/home/best/ray_temp/'
def gen_tune_config(args, param_space, param_bound):
    """Build the tune.TuneConfig for the scheduler selected on the CLI.

    Parameters:
        args: parsed argparse namespace; reads args.scheduler
            ('PBT', 'PB2' or 'ASHA') and args.synch (synchronous
            perturbation for the population-based schedulers).
        param_space: hyperparameter mutation space (used by PBT).
        param_bound: hyperparameter bounds (used by PB2).

    Returns:
        A configured tune.TuneConfig instance.

    Raises:
        ValueError: if args.scheduler is not a supported scheduler name.
    """
    if args.scheduler == 'PBT':
        # Import lazily so only the selected scheduler's module is loaded.
        from ray.tune.schedulers import PopulationBasedTraining
        scheduler = PopulationBasedTraining(
            time_attr="training_iteration",
            perturbation_interval=1,
            metric="adjusted_score",
            mode="max",
            hyperparam_mutations=param_space,
            # quantile_fraction=0.333,
            # resample_probability=0.,
            # Bug fix: was hard-coded to False, silently ignoring --synch.
            synch=args.synch,
        )
        tune_config = tune.TuneConfig(num_samples=nb_threads, scheduler=scheduler)
    elif args.scheduler == 'PB2':
        from ray.tune.schedulers.pb2 import PB2
        scheduler = PB2(
            time_attr="training_iteration",
            perturbation_interval=1,
            metric="adjusted_score",
            mode="max",
            hyperparam_bounds=param_bound,
            quantile_fraction=0.333,
            # Bug fix: was hard-coded to True, silently ignoring --synch.
            # NOTE(review): default behavior changes from synchronous to
            # asynchronous unless --synch is passed — confirm this is intended.
            synch=args.synch,
        )
        tune_config = tune.TuneConfig(num_samples=nb_threads, scheduler=scheduler)
    elif args.scheduler == 'ASHA':
        from ray.tune.schedulers import ASHAScheduler
        tune_config = tune.TuneConfig(
            metric='adjusted_score',
            mode='max',
            # Look at performance after 1 grace iteration, stop each trial at 4.
            scheduler=ASHAScheduler(grace_period=1, max_t=4),
            num_samples=50,  # initial population of 50 trials
        )
    else:
        # ValueError is more precise than a bare Exception; callers
        # catching Exception still see it.
        raise ValueError(f'scheduler {args.scheduler} not known')
    return tune_config
if __name__ == '__main__':
    # CLI entry point: run a Ray Tune hyperparameter search around the
    # AlphaZero-style trainable myfunc imported from rayConfig.
    parser = argparse.ArgumentParser(description='Apply Ray Tune to Alphazero-based training')
    parser.add_argument('trial' , action='store' , help='Trial/experiment name')
    parser.add_argument('--scheduler' , action='store', default='PBT' , help='PBT or PB2 or ASHA')
    parser.add_argument('--synch' , action='store_true' , help='Enable synch option in scheduler')
    parser.add_argument('--init-dir' , action='store', required=True , help='Folder where is initial NN')
    parser.add_argument('--comp-dir' , action='store', default=None , help='NN to compare each experiment with')
    # TODO: handle relative dirs
    # TODO: auto trial name ?
    # TODO: -m -e -p -d -V -P -ppit
    # TODO: params to enable/disable
    # TODO: params space
    args = parser.parse_args()
    # Default the comparison network to the initial network when not given.
    if args.comp_dir is None:
        args.comp_dir = args.init_dir
    # Disable memory checks — must be set in the environment before ray.init().
    environ["RAY_memory_monitor_refresh_ms"] = "0"
    # Each experiment runs at most 100 training iterations; verbose=0 keeps
    # the Tune console output quiet.
    run_config=train.RunConfig(name=args.trial, verbose=0, stop={'training_iteration': 100})
    # gen_params (from rayConfig) yields the initial point, the mutation
    # space (PBT) and the bounds (PB2) — presumably derived from args;
    # see rayConfig for the exact schema.
    param_init, param_space, param_bound = gen_params(args)
    tune_config = gen_tune_config(args, param_space, param_bound)
    # The lambda closes over args and temp_dir so the trainable receives
    # the CLI settings alongside each sampled config.
    tuner = tune.Tuner(lambda config: myfunc(config, args, temp_dir), run_config=run_config, tune_config=tune_config, param_space=param_init)
    # tuner = tune.Tuner.restore(path='/home/best/ray_results/'+config_name, trainable=myfunc)
    # NOTE(review): shutdown() before init() looks like it clears any stale
    # Ray session; init() is deliberately called after Tuner construction
    # but before fit() — confirm this ordering is required by this Ray version.
    shutdown()
    init(include_dashboard=False, _temp_dir=temp_dir)
    results = tuner.fit()