From 1501b2458427f279f452e238c578c8dd673a49dc Mon Sep 17 00:00:00 2001
From: muupan
Date: Thu, 7 Nov 2019 16:48:27 +0900
Subject: [PATCH 1/4] Make the function top-level thus pickle-able

---
 chainerrl/misc/async_.py       |  9 +++++----
 tests/misc_tests/test_async.py | 28 +++++++++++++++++-----------
 2 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/chainerrl/misc/async_.py b/chainerrl/misc/async_.py
index 317d20877..131e5704f 100644
--- a/chainerrl/misc/async_.py
+++ b/chainerrl/misc/async_.py
@@ -115,6 +115,11 @@ def share_states_as_shared_arrays(optimizer):
     return shared_arrays
 
 
+def set_seed_and_run(process_idx, run_func):
+    random_seed.set_random_seed(np.random.randint(0, 2 ** 32))
+    run_func(process_idx)
+
+
 def run_async(n_process, run_func):
     """Run experiments asynchronously.
 
@@ -125,10 +130,6 @@ def run_async(n_process, run_func):
 
     processes = []
 
-    def set_seed_and_run(process_idx, run_func):
-        random_seed.set_random_seed(np.random.randint(0, 2 ** 32))
-        run_func(process_idx)
-
     for process_idx in range(n_process):
         processes.append(mp.Process(target=set_seed_and_run, args=(
            process_idx, run_func)))
diff --git a/tests/misc_tests/test_async.py b/tests/misc_tests/test_async.py
index 6e52de85d..e23effb1f 100644
--- a/tests/misc_tests/test_async.py
+++ b/tests/misc_tests/test_async.py
@@ -6,6 +6,7 @@
 from future import standard_library
 standard_library.install_aliases()  # NOQA
 
+import functools
 import multiprocessing as mp
 import os
 import signal
@@ -22,6 +23,20 @@
 from chainerrl.misc import async_
 
 
+def _increment_counter_x1000(process_idx, counter):
+    for _ in range(1000):
+        with counter.get_lock():
+            counter.value += 1
+
+
+def run_with_exit_code_0(process_idx):
+    sys.exit(0)
+
+
+def run_with_exit_code_11(process_idx):
+    os.kill(os.getpid(), signal.SIGSEGV)
+
+
 class TestAsync(unittest.TestCase):
 
     def setUp(self):
@@ -185,22 +200,13 @@ def test_shared_link_copy(self):
 
     def test_run_async(self):
         counter = mp.Value('l', 0)
-
-        def run_func(process_idx):
-            for _ in range(1000):
-                with counter.get_lock():
-                    counter.value += 1
+        run_func = functools.partial(
+            _increment_counter_x1000, counter=counter)
         async_.run_async(4, run_func)
         self.assertEqual(counter.value, 4000)
 
     def test_run_async_exit_code(self):
 
-        def run_with_exit_code_0(process_idx):
-            sys.exit(0)
-
-        def run_with_exit_code_11(process_idx):
-            os.kill(os.getpid(), signal.SIGSEGV)
-
         with warnings.catch_warnings(record=True) as ws:
             async_.run_async(4, run_with_exit_code_0)
             # There should be no AbnormalExitWarning
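Why this patch matters: pickle serializes a function by its qualified name, so a function defined inside another function cannot be pickled, and multiprocessing start methods that pickle the Process target ("spawn", the default on Windows and on macOS since Python 3.8, and "forkserver") fail on such targets. A top-level function, optionally combined with functools.partial as the test above now does, survives the round trip. A minimal standalone sketch of both cases; the names here are illustrative, not ChainerRL APIs:

    import functools
    import pickle


    def _top_level(process_idx, counter):
        # Picklable: pickle stores a reference to the function's
        # qualified name, resolvable at import time.
        return process_idx + counter


    def make_nested():
        def nested(process_idx):
            return process_idx
        return nested


    # A partial of a top-level function round-trips fine: it pickles the
    # function reference plus the bound arguments.
    pickle.loads(pickle.dumps(functools.partial(_top_level, counter=10)))

    try:
        pickle.dumps(make_nested())
    except (pickle.PicklingError, AttributeError) as exc:
        print('nested function is not picklable:', exc)
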
From 95c81977a66bf19f48937621294000d029ac0ca9 Mon Sep 17 00:00:00 2001
From: muupan
Date: Thu, 7 Nov 2019 17:17:52 +0900
Subject: [PATCH 2/4] Make run_func pickle-able

---
 chainerrl/experiments/train_agent_async.py | 111 +++++++++++++--------
 1 file changed, 69 insertions(+), 42 deletions(-)

diff --git a/chainerrl/experiments/train_agent_async.py b/chainerrl/experiments/train_agent_async.py
index 7189d20d2..46bbcdf14 100644
--- a/chainerrl/experiments/train_agent_async.py
+++ b/chainerrl/experiments/train_agent_async.py
@@ -6,6 +6,7 @@
 from future import standard_library
 standard_library.install_aliases()  # NOQA
 
+import functools
 import logging
 import multiprocessing as mp
 import os
@@ -129,6 +130,56 @@ def set_shared_objects(agent, shared_objects):
         setattr(agent, attr, new_value)
 
 
+def _run_func(
+    process_idx,
+    make_env,
+    make_agent,
+    agent,
+    shared_objects,
+    profile,
+    evaluator,
+    **train_loop_kwargs
+):
+    """This function is run by a training loop process.
+
+    To be pickle-able, this is defined as a top-level function.
+    """
+    random_seed.set_random_seed(process_idx)
+
+    env = make_env(process_idx, test=False)
+    if evaluator is None:
+        eval_env = env
+    else:
+        eval_env = make_env(process_idx, test=True)
+    if make_agent is not None:
+        local_agent = make_agent(process_idx)
+        set_shared_objects(local_agent, shared_objects)
+    else:
+        local_agent = agent
+    local_agent.process_idx = process_idx
+
+    def f():
+        train_loop(
+            process_idx=process_idx,
+            agent=local_agent,
+            env=env,
+            eval_env=eval_env,
+            evaluator=evaluator,
+            **train_loop_kwargs,
+        )
+
+    if profile:
+        import cProfile
+        cProfile.runctx('f()', globals(), locals(),
+                        'profile-{}.out'.format(os.getpid()))
+    else:
+        f()
+
+    env.close()
+    if eval_env is not env:
+        eval_env.close()
+
+
 def train_agent_async(outdir, processes, make_env,
                       profile=False,
                       steps=8 * 10 ** 7,
@@ -205,48 +256,24 @@ def train_agent_async(outdir, processes, make_env,
         logger=logger,
     )
 
-    def run_func(process_idx):
-        random_seed.set_random_seed(process_idx)
-
-        env = make_env(process_idx, test=False)
-        if evaluator is None:
-            eval_env = env
-        else:
-            eval_env = make_env(process_idx, test=True)
-        if make_agent is not None:
-            local_agent = make_agent(process_idx)
-            set_shared_objects(local_agent, shared_objects)
-        else:
-            local_agent = agent
-        local_agent.process_idx = process_idx
-
-        def f():
-            train_loop(
-                process_idx=process_idx,
-                counter=counter,
-                episodes_counter=episodes_counter,
-                agent=local_agent,
-                env=env,
-                steps=steps,
-                outdir=outdir,
-                max_episode_len=max_episode_len,
-                evaluator=evaluator,
-                successful_score=successful_score,
-                training_done=training_done,
-                eval_env=eval_env,
-                global_step_hooks=global_step_hooks,
-                logger=logger)
-
-        if profile:
-            import cProfile
-            cProfile.runctx('f()', globals(), locals(),
-                            'profile-{}.out'.format(os.getpid()))
-        else:
-            f()
-
-        env.close()
-        if eval_env is not env:
-            eval_env.close()
+    run_func = functools.partial(
+        _run_func,
+        make_agent=make_agent,
+        agent=agent if make_agent is None else None,
+        make_env=make_env,
+        evaluator=evaluator,
+        shared_objects=shared_objects,
+        profile=profile,
+        steps=steps,
+        outdir=outdir,
+        counter=counter,
+        episodes_counter=episodes_counter,
+        training_done=training_done,
+        max_episode_len=max_episode_len,
+        successful_score=successful_score,
+        logger=logger,
+        global_step_hooks=global_step_hooks,
+    )
 
     async_.run_async(processes, run_func)
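Two details of this patch are worth noting. The partial pre-binds everything except process_idx, and it passes agent only when make_agent is None, so a worker that can rebuild its agent locally never has to pickle the shared one. Keyword arguments that _run_func does not name explicitly (steps, outdir, counter, ...) fall through **train_loop_kwargs and reach train_loop unchanged. A toy model of that forwarding, with illustrative names rather than the ChainerRL API:

    import functools


    def _fake_env(process_idx):
        return 'env-{}'.format(process_idx)


    def _worker(process_idx, make_env, **loop_kwargs):
        # Named parameters are consumed here; everything else passes
        # through untouched, mirroring _run_func's **train_loop_kwargs.
        return dict(process_idx=process_idx, env=make_env(process_idx),
                    **loop_kwargs)


    run = functools.partial(_worker, make_env=_fake_env, steps=100,
                            outdir='out')
    print(run(0))
    # {'process_idx': 0, 'env': 'env-0', 'steps': 100, 'outdir': 'out'}
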
From c4f15d2366ac85b0b499f5443d7fded60cf63565 Mon Sep 17 00:00:00 2001
From: muupan
Date: Thu, 7 Nov 2019 17:18:57 +0900
Subject: [PATCH 3/4] Make make_env pickle-able

---
 examples/gym/train_a3c_gym.py | 39 ++++++++++++++++++++-----------------
 1 file changed, 22 insertions(+), 17 deletions(-)

diff --git a/examples/gym/train_a3c_gym.py b/examples/gym/train_a3c_gym.py
index 199beaff1..da3683b07 100644
--- a/examples/gym/train_a3c_gym.py
+++ b/examples/gym/train_a3c_gym.py
@@ -17,6 +17,7 @@
 from future import standard_library
 standard_library.install_aliases()  # NOQA
 import argparse
+import functools
 import os
 
 # This prevents numpy from using multiple threads
@@ -92,6 +93,25 @@ def forward(head, lstm, tail):
         return pout, vout
 
 
+def _make_env(process_idx, test, process_seeds, args):
+    env = gym.make(args.env)
+    # Use different random seeds for train and test envs
+    process_seed = int(process_seeds[process_idx])
+    env_seed = 2 ** 32 - 1 - process_seed if test else process_seed
+    env.seed(env_seed)
+    # Cast observations to float32 because our model uses float32
+    env = chainerrl.wrappers.CastObservationToFloat32(env)
+    if args.monitor and process_idx == 0:
+        env = chainerrl.wrappers.Monitor(env, args.outdir)
+    if not test:
+        # Scale rewards (and thus returns) to a reasonable range so that
+        # training is easier
+        env = chainerrl.wrappers.ScaleReward(env, args.reward_scale_factor)
+    if args.render and process_idx == 0 and not test:
+        env = chainerrl.wrappers.Render(env)
+    return env
+
+
 def main():
     import logging
 
@@ -137,23 +157,8 @@ def main():
 
     args.outdir = experiments.prepare_output_dir(args, args.outdir)
 
-    def make_env(process_idx, test):
-        env = gym.make(args.env)
-        # Use different random seeds for train and test envs
-        process_seed = int(process_seeds[process_idx])
-        env_seed = 2 ** 32 - 1 - process_seed if test else process_seed
-        env.seed(env_seed)
-        # Cast observations to float32 because our model uses float32
-        env = chainerrl.wrappers.CastObservationToFloat32(env)
-        if args.monitor and process_idx == 0:
-            env = chainerrl.wrappers.Monitor(env, args.outdir)
-        if not test:
-            # Scale rewards (and thus returns) to a reasonable range so that
-            # training is easier
-            env = chainerrl.wrappers.ScaleReward(env, args.reward_scale_factor)
-        if args.render and process_idx == 0 and not test:
-            env = chainerrl.wrappers.Render(env)
-        return env
+    make_env = functools.partial(
+        _make_env, process_seeds=process_seeds, args=args)
 
     sample_env = gym.make(args.env)
     timestep_limit = sample_env.spec.tags.get(
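The seeding logic carried into _make_env gives every process distinct train and test seeds: the test env receives the seed mirrored across the 32-bit range, so the two sets cannot overlap as long as the process seeds stay below 2 ** 31. A quick standalone check of that property (the helper name is made up for illustration):

    def train_and_test_seeds(process_seed):
        # Mirrors _make_env: test envs get the mirrored 32-bit seed.
        return process_seed, 2 ** 32 - 1 - process_seed


    pairs = [train_and_test_seeds(s) for s in range(8)]
    train_seeds = {t for t, _ in pairs}
    test_seeds = {u for _, u in pairs}
    assert not train_seeds & test_seeds, 'train/test seed collision'
    print(pairs[0], pairs[1])  # (0, 4294967295) (1, 4294967294)
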
From a7ac0f535598f34c3d7f6975c1e243e0ab88c011 Mon Sep 17 00:00:00 2001
From: muupan
Date: Thu, 7 Nov 2019 17:21:23 +0900
Subject: [PATCH 4/4] Make make_agent pickle-able

---
 examples/gym/train_a3c_gym.py | 59 ++++++++++++++++++++++++-------------
 1 file changed, 37 insertions(+), 22 deletions(-)

diff --git a/examples/gym/train_a3c_gym.py b/examples/gym/train_a3c_gym.py
index da3683b07..f756de6ab 100644
--- a/examples/gym/train_a3c_gym.py
+++ b/examples/gym/train_a3c_gym.py
@@ -112,6 +112,31 @@ def _make_env(process_idx, test, process_seeds, args):
     return env
 
 
+def _make_agent(process_idx, obs_space, action_space, args):
+
+    # Switch policy types accordingly to action space types
+    if args.arch == 'LSTMGaussian':
+        model = A3CLSTMGaussian(obs_space.low.size, action_space.low.size)
+    elif args.arch == 'FFSoftmax':
+        model = A3CFFSoftmax(obs_space.low.size, action_space.n)
+    elif args.arch == 'FFMellowmax':
+        model = A3CFFMellowmax(obs_space.low.size, action_space.n)
+
+    opt = rmsprop_async.RMSpropAsync(
+        lr=args.lr, eps=args.rmsprop_epsilon, alpha=0.99)
+    opt.setup(model)
+    opt.add_hook(chainer.optimizer.GradientClipping(40))
+    if args.weight_decay > 0:
+        opt.add_hook(NonbiasWeightDecay(args.weight_decay))
+
+    agent = a3c.A3C(model, opt, t_max=args.t_max, gamma=0.99,
+                    beta=args.beta)
+    if args.load:
+        agent.load(args.load)
+
+    return agent
+
+
 def main():
     import logging
 
@@ -166,49 +191,39 @@ def main():
     obs_space = sample_env.observation_space
     action_space = sample_env.action_space
 
-    # Switch policy types accordingly to action space types
-    if args.arch == 'LSTMGaussian':
-        model = A3CLSTMGaussian(obs_space.low.size, action_space.low.size)
-    elif args.arch == 'FFSoftmax':
-        model = A3CFFSoftmax(obs_space.low.size, action_space.n)
-    elif args.arch == 'FFMellowmax':
-        model = A3CFFMellowmax(obs_space.low.size, action_space.n)
-
-    opt = rmsprop_async.RMSpropAsync(
-        lr=args.lr, eps=args.rmsprop_epsilon, alpha=0.99)
-    opt.setup(model)
-    opt.add_hook(chainer.optimizer.GradientClipping(40))
-    if args.weight_decay > 0:
-        opt.add_hook(NonbiasWeightDecay(args.weight_decay))
-
-    agent = a3c.A3C(model, opt, t_max=args.t_max, gamma=0.99,
-                    beta=args.beta)
-    if args.load:
-        agent.load(args.load)
+    make_agent = functools.partial(
+        _make_agent,
+        obs_space=obs_space,
+        action_space=action_space,
+        args=args,
+    )
 
     if args.demo:
+        agent = make_agent(0)
         env = make_env(0, True)
         eval_stats = experiments.eval_performance(
             env=env,
             agent=agent,
             n_steps=None,
             n_episodes=args.eval_n_runs,
-            max_episode_len=timestep_limit)
+            max_episode_len=timestep_limit,
+        )
         print('n_runs: {} mean: {} median: {} stdev {}'.format(
             args.eval_n_runs, eval_stats['mean'], eval_stats['median'],
             eval_stats['stdev']))
     else:
         experiments.train_agent_async(
-            agent=agent,
+            make_agent=make_agent,
             outdir=args.outdir,
             processes=args.processes,
            make_env=make_env,
             profile=args.profile,
             steps=args.steps,
             eval_n_steps=None,
             eval_n_episodes=args.eval_n_runs,
             eval_interval=args.eval_interval,
-            max_episode_len=timestep_limit)
+            max_episode_len=timestep_limit,
+        )
 
 
 if __name__ == '__main__':
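One practical consequence of the series as a whole: every object bound into the partials (the argparse Namespace, the process_seeds array, the evaluator, the shared objects) must itself be picklable before spawn-based multiprocessing can launch workers. A small vetting helper, hypothetical and not part of these patches, that fails fast with a readable error instead of deep inside mp.Process:

    import functools
    import pickle


    def assert_picklable(obj, name='run_func'):
        # Round-trip through pickle so failures surface before any
        # child process is spawned.
        try:
            pickle.loads(pickle.dumps(obj))
        except Exception as exc:
            raise TypeError('{} is not picklable: {}'.format(name, exc))


    def _noop(process_idx, outdir):
        return outdir


    assert_picklable(functools.partial(_noop, outdir='results'))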