111 changes: 69 additions & 42 deletions chainerrl/experiments/train_agent_async.py
@@ -6,6 +6,7 @@
from future import standard_library
standard_library.install_aliases() # NOQA

import functools
import logging
import multiprocessing as mp
import os
@@ -129,6 +130,56 @@ def set_shared_objects(agent, shared_objects):
setattr(agent, attr, new_value)


def _run_func(
process_idx,
make_env,
make_agent,
agent,
shared_objects,
profile,
evaluator,
**train_loop_kwargs
):
"""This function is run by a training loop process.

To be picklable, it is defined as a top-level function.
"""
random_seed.set_random_seed(process_idx)

env = make_env(process_idx, test=False)
if evaluator is None:
eval_env = env
else:
eval_env = make_env(process_idx, test=True)
if make_agent is not None:
local_agent = make_agent(process_idx)
set_shared_objects(local_agent, shared_objects)
else:
local_agent = agent
local_agent.process_idx = process_idx

def f():
train_loop(
process_idx=process_idx,
agent=local_agent,
env=env,
eval_env=eval_env,
evaluator=evaluator,
**train_loop_kwargs,
)

if profile:
import cProfile
cProfile.runctx('f()', globals(), locals(),
'profile-{}.out'.format(os.getpid()))
else:
f()

env.close()
if eval_env is not env:
eval_env.close()


def train_agent_async(outdir, processes, make_env,
profile=False,
steps=8 * 10 ** 7,
@@ -205,48 +256,24 @@ def train_agent_async(outdir, processes, make_env,
logger=logger,
)

def run_func(process_idx):
random_seed.set_random_seed(process_idx)

env = make_env(process_idx, test=False)
if evaluator is None:
eval_env = env
else:
eval_env = make_env(process_idx, test=True)
if make_agent is not None:
local_agent = make_agent(process_idx)
set_shared_objects(local_agent, shared_objects)
else:
local_agent = agent
local_agent.process_idx = process_idx

def f():
train_loop(
process_idx=process_idx,
counter=counter,
episodes_counter=episodes_counter,
agent=local_agent,
env=env,
steps=steps,
outdir=outdir,
max_episode_len=max_episode_len,
evaluator=evaluator,
successful_score=successful_score,
training_done=training_done,
eval_env=eval_env,
global_step_hooks=global_step_hooks,
logger=logger)

if profile:
import cProfile
cProfile.runctx('f()', globals(), locals(),
'profile-{}.out'.format(os.getpid()))
else:
f()

env.close()
if eval_env is not env:
eval_env.close()
run_func = functools.partial(
_run_func,
make_agent=make_agent,
agent=agent if make_agent is None else None,
make_env=make_env,
evaluator=evaluator,
shared_objects=shared_objects,
profile=profile,
steps=steps,
outdir=outdir,
counter=counter,
episodes_counter=episodes_counter,
training_done=training_done,
max_episode_len=max_episode_len,
successful_score=successful_score,
logger=logger,
global_step_hooks=global_step_hooks,
)

async_.run_async(processes, run_func)

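A hedged aside on the refactoring above (the names _worker and greeting are illustrative, not part of ChainerRL): multiprocessing pickles the process target under the 'spawn' and 'forkserver' start methods. A function defined inside train_agent_async cannot be pickled, but a functools.partial over a module-level function can, because top-level functions pickle by reference. A minimal sketch:

import functools
import multiprocessing as mp
import pickle


def _worker(process_idx, greeting):
    # Module-level function: pickled by reference, so it round-trips fine.
    print('{} from process {}'.format(greeting, process_idx))


def main():
    run_func = functools.partial(_worker, greeting='hello')
    pickle.dumps(run_func)  # OK; a closure defined inside main() would raise here
    processes = [mp.Process(target=run_func, args=(i,)) for i in range(2)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()


if __name__ == '__main__':
    main()

With the old closure-based run_func, the pickle.dumps call above is exactly where things break on platforms whose default start method is 'spawn' (Windows, and macOS since Python 3.8).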
9 changes: 5 additions & 4 deletions chainerrl/misc/async_.py
@@ -115,6 +115,11 @@ def share_states_as_shared_arrays(optimizer):
return shared_arrays


def set_seed_and_run(process_idx, run_func):
random_seed.set_random_seed(np.random.randint(0, 2 ** 32))
run_func(process_idx)


def run_async(n_process, run_func):
"""Run experiments asynchronously.

@@ -125,10 +130,6 @@ def run_async(n_process, run_func):

processes = []

def set_seed_and_run(process_idx, run_func):
random_seed.set_random_seed(np.random.randint(0, 2 ** 32))
run_func(process_idx)

for process_idx in range(n_process):
processes.append(mp.Process(target=set_seed_and_run, args=(
process_idx, run_func)))
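Background for moving set_seed_and_run to module level (a sketch assuming Python 3's multiprocessing; not part of the diff): mp.Process pickles its target only under the 'spawn' and 'forkserver' start methods, so the old nested definition happened to work under Linux's default 'fork' but fails elsewhere. A module-level target works under all three:

import multiprocessing as mp


def _target(process_idx):
    print('running in process', process_idx)


if __name__ == '__main__':
    ctx = mp.get_context('spawn')  # start method that must pickle the target
    p = ctx.Process(target=_target, args=(0,))
    p.start()  # would raise a pickling error if _target were a nested function
    p.join()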
98 changes: 59 additions & 39 deletions examples/gym/train_a3c_gym.py
@@ -17,6 +17,7 @@
from future import standard_library
standard_library.install_aliases() # NOQA
import argparse
import functools
import os

# This prevents numpy from using multiple threads
@@ -92,6 +93,50 @@ def forward(head, lstm, tail):
return pout, vout


def _make_env(process_idx, test, process_seeds, args):
env = gym.make(args.env)
# Use different random seeds for train and test envs
process_seed = int(process_seeds[process_idx])
env_seed = 2 ** 32 - 1 - process_seed if test else process_seed
env.seed(env_seed)
# Cast observations to float32 because our model uses float32
env = chainerrl.wrappers.CastObservationToFloat32(env)
if args.monitor and process_idx == 0:
env = chainerrl.wrappers.Monitor(env, args.outdir)
if not test:
# Scale rewards (and thus returns) to a reasonable range so that
# training is easier
env = chainerrl.wrappers.ScaleReward(env, args.reward_scale_factor)
if args.render and process_idx == 0 and not test:
env = chainerrl.wrappers.Render(env)
return env


def _make_agent(process_idx, obs_space, action_space, args):

# Switch policy types according to the action space type
if args.arch == 'LSTMGaussian':
model = A3CLSTMGaussian(obs_space.low.size, action_space.low.size)
elif args.arch == 'FFSoftmax':
model = A3CFFSoftmax(obs_space.low.size, action_space.n)
elif args.arch == 'FFMellowmax':
model = A3CFFMellowmax(obs_space.low.size, action_space.n)

opt = rmsprop_async.RMSpropAsync(
lr=args.lr, eps=args.rmsprop_epsilon, alpha=0.99)
opt.setup(model)
opt.add_hook(chainer.optimizer.GradientClipping(40))
if args.weight_decay > 0:
opt.add_hook(NonbiasWeightDecay(args.weight_decay))

agent = a3c.A3C(model, opt, t_max=args.t_max, gamma=0.99,
beta=args.beta)
if args.load:
agent.load(args.load)

return agent


def main():
import logging

@@ -137,64 +182,38 @@ def main():

args.outdir = experiments.prepare_output_dir(args, args.outdir)

def make_env(process_idx, test):
env = gym.make(args.env)
# Use different random seeds for train and test envs
process_seed = int(process_seeds[process_idx])
env_seed = 2 ** 32 - 1 - process_seed if test else process_seed
env.seed(env_seed)
# Cast observations to float32 because our model uses float32
env = chainerrl.wrappers.CastObservationToFloat32(env)
if args.monitor and process_idx == 0:
env = chainerrl.wrappers.Monitor(env, args.outdir)
if not test:
# Scale rewards (and thus returns) to a reasonable range so that
# training is easier
env = chainerrl.wrappers.ScaleReward(env, args.reward_scale_factor)
if args.render and process_idx == 0 and not test:
env = chainerrl.wrappers.Render(env)
return env
make_env = functools.partial(
_make_env, process_seeds=process_seeds, args=args)

sample_env = gym.make(args.env)
timestep_limit = sample_env.spec.tags.get(
'wrapper_config.TimeLimit.max_episode_steps')
obs_space = sample_env.observation_space
action_space = sample_env.action_space

# Switch policy types accordingly to action space types
if args.arch == 'LSTMGaussian':
model = A3CLSTMGaussian(obs_space.low.size, action_space.low.size)
elif args.arch == 'FFSoftmax':
model = A3CFFSoftmax(obs_space.low.size, action_space.n)
elif args.arch == 'FFMellowmax':
model = A3CFFMellowmax(obs_space.low.size, action_space.n)

opt = rmsprop_async.RMSpropAsync(
lr=args.lr, eps=args.rmsprop_epsilon, alpha=0.99)
opt.setup(model)
opt.add_hook(chainer.optimizer.GradientClipping(40))
if args.weight_decay > 0:
opt.add_hook(NonbiasWeightDecay(args.weight_decay))

agent = a3c.A3C(model, opt, t_max=args.t_max, gamma=0.99,
beta=args.beta)
if args.load:
agent.load(args.load)
make_agent = functools.partial(
_make_agent,
obs_space=obs_space,
action_space=action_space,
args=args,
)

if args.demo:
agent = make_agent(0)
env = make_env(0, True)
eval_stats = experiments.eval_performance(
env=env,
agent=agent,
n_steps=None,
n_episodes=args.eval_n_runs,
max_episode_len=timestep_limit)
max_episode_len=timestep_limit,
)
print('n_runs: {} mean: {} median: {} stdev {}'.format(
args.eval_n_runs, eval_stats['mean'], eval_stats['median'],
eval_stats['stdev']))
else:
experiments.train_agent_async(
agent=agent,
make_agent=make_agent,
outdir=args.outdir,
processes=args.processes,
make_env=make_env,
@@ -203,7 +222,8 @@ def make_env(process_idx, test):
eval_n_steps=None,
eval_n_episodes=args.eval_n_runs,
eval_interval=args.eval_interval,
max_episode_len=timestep_limit)
max_episode_len=timestep_limit,
)


if __name__ == '__main__':
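A short, hedged illustration of how the two partials above are consumed (only functools.partial and the make_env(process_idx, test) calling convention come from the code above; the values are stand-ins): binding process_seeds and args as keyword arguments up front preserves the signature that train_agent_async and the demo branch already expect, while the underlying function stays at module level and therefore picklable.

import functools


def _make_env(process_idx, test, process_seeds, args):
    # Stand-in for the real _make_env defined above.
    seed = int(process_seeds[process_idx])
    env_seed = 2 ** 32 - 1 - seed if test else seed
    return 'env(seed={}, test={})'.format(env_seed, test)


if __name__ == '__main__':
    make_env = functools.partial(
        _make_env, process_seeds=[11, 22, 33], args=None)
    # Callers keep the original (process_idx, test) signature:
    print(make_env(0, test=False))
    print(make_env(2, test=True))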
28 changes: 17 additions & 11 deletions tests/misc_tests/test_async.py
@@ -6,6 +6,7 @@
from future import standard_library
standard_library.install_aliases() # NOQA

import functools
import multiprocessing as mp
import os
import signal
@@ -22,6 +23,20 @@
from chainerrl.misc import async_


def _increment_counter_x1000(process_idx, counter):
for _ in range(1000):
with counter.get_lock():
counter.value += 1


def run_with_exit_code_0(process_idx):
sys.exit(0)


def run_with_exit_code_11(process_idx):
os.kill(os.getpid(), signal.SIGSEGV)


class TestAsync(unittest.TestCase):

def setUp(self):
@@ -185,22 +200,13 @@ def test_shared_link_copy(self):

def test_run_async(self):
counter = mp.Value('l', 0)

def run_func(process_idx):
for _ in range(1000):
with counter.get_lock():
counter.value += 1
run_func = functools.partial(
_increment_counter_x1000, counter=counter)
async_.run_async(4, run_func)
self.assertEqual(counter.value, 4000)

def test_run_async_exit_code(self):

def run_with_exit_code_0(process_idx):
sys.exit(0)

def run_with_exit_code_11(process_idx):
os.kill(os.getpid(), signal.SIGSEGV)

with warnings.catch_warnings(record=True) as ws:
async_.run_async(4, run_with_exit_code_0)
# There should be no AbnormalExitWarning