From b23cc13a4ea1fa8340e4285a2c774a73ff4a6672 Mon Sep 17 00:00:00 2001 From: muupan Date: Sun, 10 Nov 2019 19:52:01 +0900 Subject: [PATCH 1/6] Add weights argument to QuantileDiscreteActionValue --- chainerrl/action_value.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/chainerrl/action_value.py b/chainerrl/action_value.py index 61f1d217f..6db752c51 100644 --- a/chainerrl/action_value.py +++ b/chainerrl/action_value.py @@ -196,12 +196,19 @@ class QuantileDiscreteActionValue(DiscreteActionValue): Args: quantiles (chainer.Variable): (batch_size, n_taus, n_actions) + weights (None or chainer.Variable): (batch_size, n_taus) q_values_formatter (callable): """ - def __init__(self, quantiles, q_values_formatter=lambda x: x): + def __init__( + self, + quantiles, + weights=None, + q_values_formatter=lambda x: x, + ): assert quantiles.ndim == 3 self.quantiles = quantiles + self.weights = weights self.xp = cuda.get_array_module(quantiles.array) self.n_actions = quantiles.shape[2] self.q_values_formatter = q_values_formatter @@ -209,7 +216,10 @@ def __init__(self, quantiles, q_values_formatter=lambda x: x): @cached_property def q_values(self): with chainer.force_backprop_mode(): - return F.mean(self.quantiles, axis=1) + if self.weights is not None: + return F.sum(self.weights[..., None] * self.quantiles, axis=1) + else: + return F.mean(self.quantiles, axis=1) def evaluate_actions_as_quantiles(self, actions): """Return the return quantiles of given actions. From 983470052701f282113d6265087b49985f350822 Mon Sep 17 00:00:00 2001 From: muupan Date: Sun, 10 Nov 2019 19:52:26 +0900 Subject: [PATCH 2/6] Allow additional arguments for StatelessRecurrentChainList --- chainerrl/links/stateless_recurrent.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/chainerrl/links/stateless_recurrent.py b/chainerrl/links/stateless_recurrent.py index d41c088ef..8041731dd 100644 --- a/chainerrl/links/stateless_recurrent.py +++ b/chainerrl/links/stateless_recurrent.py @@ -69,7 +69,7 @@ def n_step_forward(self, x, recurrent_state): """ raise NotImplementedError - def __call__(self, x, recurrent_state): + def __call__(self, x, recurrent_state, *args, **kwargs): """One-step batch forward computation. 
Args: @@ -85,6 +85,8 @@ def __call__(self, x, recurrent_state): split_one_step_batch_input(x), recurrent_state, output_mode='concat', + *args, + **kwargs, ) def mask_recurrent_state_at(self, recurrent_state, indices): From 08551b5e96f6d86a34b3f0827f89d47ba946720b Mon Sep 17 00:00:00 2001 From: muupan Date: Sun, 10 Nov 2019 19:53:00 +0900 Subject: [PATCH 3/6] Implement FQF --- chainerrl/agents/__init__.py | 1 + chainerrl/agents/fqf.py | 350 +++++++++++++++++++++++++++++++++ tests/agents_tests/test_fqf.py | 83 ++++++++ 3 files changed, 434 insertions(+) create mode 100644 chainerrl/agents/fqf.py create mode 100644 tests/agents_tests/test_fqf.py diff --git a/chainerrl/agents/__init__.py b/chainerrl/agents/__init__.py index 99d878935..14de727be 100644 --- a/chainerrl/agents/__init__.py +++ b/chainerrl/agents/__init__.py @@ -10,6 +10,7 @@ from chainerrl.agents.double_pal import DoublePAL # NOQA from chainerrl.agents.dpp import DPP # NOQA from chainerrl.agents.dqn import DQN # NOQA +from chainerrl.agents.fqf import FQF # NOQA from chainerrl.agents.iqn import IQN # NOQA from chainerrl.agents.nsq import NSQ # NOQA from chainerrl.agents.pal import PAL # NOQA diff --git a/chainerrl/agents/fqf.py b/chainerrl/agents/fqf.py new file mode 100644 index 000000000..e8c791d32 --- /dev/null +++ b/chainerrl/agents/fqf.py @@ -0,0 +1,350 @@ +from __future__ import unicode_literals +from __future__ import print_function +from __future__ import division +from __future__ import absolute_import +from builtins import * # NOQA +from future import standard_library +standard_library.install_aliases() # NOQA +import collections + +import chainer +from chainer import cuda +import chainer.functions as F +import numpy as np + +from chainerrl.action_value import QuantileDiscreteActionValue +from chainerrl.agents import dqn +from chainerrl.links import StatelessRecurrentChainList +from chainerrl.agents import iqn + + +def _mean_or_nan(xs): + """Return its mean a non-empty sequence, numpy.nan for a empty one.""" + return np.mean(xs) if xs else np.nan + + +def _evaluate_psi_x_with_quantile_thresholds( + psi_x, phi, f, taus, weights=None): + assert psi_x.ndim == 2 + batch_size, hidden_size = psi_x.shape + assert taus.ndim == 2 + assert taus.shape[0] == batch_size + n_taus = taus.shape[1] + phi_taus = phi(taus) + assert phi_taus.ndim == 3 + assert phi_taus.shape == (batch_size, n_taus, hidden_size) + psi_x_b = F.broadcast_to( + F.expand_dims(psi_x, axis=1), phi_taus.shape) + h = psi_x_b * phi_taus + h = F.reshape(h, (-1, hidden_size)) + assert h.shape == (batch_size * n_taus, hidden_size) + h = f(h) + assert h.ndim == 2 + assert h.shape[0] == batch_size * n_taus + n_actions = h.shape[-1] + h = F.reshape(h, (batch_size, n_taus, n_actions)) + return QuantileDiscreteActionValue(h, weights=weights) + + +def _compute_taus_hat_and_weights(taus): + batch_size, _ = taus.shape + xp = chainer.cuda.get_array_module(taus) + assert xp.allclose(taus[:, -1], xp.ones(batch_size)) + # shifted_tau: [0, tau_0, tau_1, ..., tau_{N-2}] + shifted_tau = xp.concatenate( + [xp.zeros((batch_size, 1), dtype=taus.dtype), taus[:, :-1]], + axis=1, + ) + weights = taus - shifted_tau + taus_hat = (taus + shifted_tau) / 2 + assert taus_hat.shape == taus.shape + assert weights.shape == weights.shape + return taus_hat, weights + + +class FQQFunction(chainer.Chain): + + """Fully-parameterized Quantile Q-function. + + Args: + psi (chainer.Link): Callable link + (batch_size, obs_size) -> (batch_size, hidden_size). 
+ phi (chainer.Link): Callable link + (batch_size, n_taus) -> (batch_size, n_taus, hidden_size). + f (chainer.Link): Callable link + (batch_size * n_taus, hidden_size) + -> (batch_size * n_taus, n_actions). + proposal_net (chainer.Link): Callable link + (batch_size, hidden_size) -> (batch_size, n_taus) + """ + + def __init__(self, psi, phi, f, proposal_net): + super().__init__() + with self.init_scope(): + self.psi = psi + self.phi = phi + self.f = f + self.proposal_net = proposal_net + + def __call__(self, x, taus=None, with_tau_quantiles=False): + """Evaluate given observations. + + Args: + x (ndarray): Batch of observations. + taus (None or ndarray): Taus (Quantile thresholds). If set to None, + the proposal net is used to compute taus. + with_tau_quantiles (bool): If set to True, results with taus are + returned besides ones with taus hat. + + Returns: + QuantileDiscreteActionValue: ActionValue based on tau hat. + ndarray: Tau hat. + QuantileDiscreteActionValue: ActionValue based on tau. + Returned only when with_tau_quantiles=True. + ndarray: Tau. Returned only when with_tau_quantiles=True. + """ + batch_size = x.shape[0] + psi_x = self.psi(x) + assert psi_x.ndim == 2 + assert psi_x.shape[0] == batch_size + if taus is None: + # Make sure errors of the proposal net do not backprop to psi + taus = F.cumsum( + F.softmax(self.proposal_net(psi_x.array), axis=1), axis=1) + + # Quantiles based on tau hat, used to compute Q-values + taus_hat, weights = _compute_taus_hat_and_weights( + _unwrap_variable(taus)) + tau_hat_av = _evaluate_psi_x_with_quantile_thresholds( + psi_x, self.phi, self.f, taus_hat, weights=weights) + + if with_tau_quantiles: + # Quantiles based on tau, used to update the proposal net + tau_av = _evaluate_psi_x_with_quantile_thresholds( + psi_x, self.phi, self.f, _unwrap_variable(taus)) + return tau_hat_av, taus_hat, tau_av, taus + else: + return tau_hat_av, taus_hat + + +class StatelessRecurrentFQQFunction( + StatelessRecurrentChainList): + + """Recurrent Fully-parameterized Quantile Q-function. + + Args: + psi (chainer.Link): Link that implements + `chainerrl.links.StatelessRecurrent`. + (batch_size, obs_size) -> (batch_size, hidden_size). + phi (chainer.Link): Callable link + (batch_size, n_taus) -> (batch_size, n_taus, hidden_size). + f (chainer.Link): Callable link + (batch_size * n_taus, hidden_size) + -> (batch_size * n_taus, n_actions). + proposal_net (chainer.Link): Callable link + (batch_size, hidden_size) -> (batch_size, n_taus) + """ + + def __init__(self, psi, phi, f, proposal_net): + super().__init__(psi, phi, f, proposal_net) + self.psi = psi + self.phi = phi + self.f = f + self.proposal_net = proposal_net + + def n_step_forward( + self, + x, + recurrent_state, + output_mode, + taus=None, + with_tau_quantiles=False, + ): + """Evaluate given observations. + + Args: + x (ndarray): Batch of observations. 
+ Returns: + callable: (batch_size, taus) -> (batch_size, taus, n_actions) + """ + assert output_mode == 'concat' + if recurrent_state is not None: + recurrent_state, = recurrent_state + psi_x, recurrent_state = self.psi.n_step_forward( + x, recurrent_state, output_mode='concat') + assert psi_x.ndim == 2 + + if taus is None: + # Make sure errors of the proposal net do not backprop to psi + taus = F.cumsum( + F.softmax(self.proposal_net(psi_x.array), axis=1), axis=1) + + # Quantiles based on tau hat, used to compute Q-values + taus_hat, weights = _compute_taus_hat_and_weights( + _unwrap_variable(taus)) + tau_hat_av = _evaluate_psi_x_with_quantile_thresholds( + psi_x, self.phi, self.f, taus_hat, weights=weights) + + if with_tau_quantiles: + # Quantiles based on tau, used to update the proposal net + tau_av = _evaluate_psi_x_with_quantile_thresholds( + psi_x, self.phi, self.f, _unwrap_variable(taus)) + return (tau_hat_av, taus_hat, tau_av, taus), (recurrent_state,) + else: + return (tau_hat_av, taus_hat), (recurrent_state,) + + +def _unwrap_variable(x): + if isinstance(x, chainer.Variable): + return x.array + else: + return x + + +class FQF(dqn.DQN): + + """Fully-parameterized Quantile Function (FQF) algorithm. + + See http://arxiv.org/abs/1911.02140. + + Args: + model (FQQFunction): Q-function link to train. + + For other arguments, see chainerrl.agents.DQN. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.proposal_loss_record = collections.deque(maxlen=100) + self.proposal_entropy_record = collections.deque(maxlen=100) + + def _compute_target_values(self, exp_batch, taus): + """Compute a batch of target return distributions. + + Returns: + chainer.Variable: (batch_size, N_prime). + """ + batch_next_state = exp_batch['next_state'] + batch_size = len(exp_batch['reward']) + + if self.recurrent: + (target_next_av, _), _ = self.target_model.n_step_forward( + batch_next_state, + exp_batch['next_recurrent_state'], + output_mode='concat', + taus=taus, + ) + else: + target_next_av, _ = self.target_model(batch_next_state, taus=taus) + greedy_actions = target_next_av.greedy_actions + target_next_maxz = target_next_av.evaluate_actions_as_quantiles( + greedy_actions) + + batch_rewards = exp_batch['reward'] + batch_terminal = exp_batch['is_state_terminal'] + batch_discount = exp_batch['discount'] + assert batch_rewards.shape == (batch_size,) + assert batch_terminal.shape == (batch_size,) + assert batch_discount.shape == (batch_size,) + batch_rewards = F.broadcast_to( + batch_rewards[..., None], target_next_maxz.shape) + batch_terminal = F.broadcast_to( + batch_terminal[..., None], target_next_maxz.shape) + batch_discount = F.broadcast_to( + batch_discount[..., None], target_next_maxz.shape) + + return (batch_rewards + + batch_discount * (1.0 - batch_terminal) * target_next_maxz) + + def _compute_predictions(self, exp_batch): + """Compute a batch of predicted return distributions. + + Returns: + chainer.Variable: Predicted return distributions. + (batch_size, N). 
+ """ + + # Compute Q-values for current states + batch_state = exp_batch['state'] + + # (batch_size, n_actions, n_atoms) + if self.recurrent: + (tau_hat_av, taus_hat, tau_av, taus), _ = self.model.n_step_forward( # NOQA + batch_state, + exp_batch['recurrent_state'], + output_mode='concat', + with_tau_quantiles=True, + ) + else: + tau_hat_av, taus_hat, tau_av, taus = self.model( + batch_state, with_tau_quantiles=True) + batch_actions = exp_batch['action'] + taus_hat_quantiles = tau_hat_av.evaluate_actions_as_quantiles( + batch_actions) + tau_quantiles = tau_av.evaluate_actions_as_quantiles(batch_actions) + + return taus_hat_quantiles, taus_hat, tau_quantiles, taus + + def _compute_loss(self, exp_batch, errors_out=None): + """Compute a loss. + + Returns: + Returns: + chainer.Variable: Scalar loss. + """ + tau_hat_quantiles, taus_hat, tau_quantiles, taus =\ + self._compute_predictions(exp_batch) + with chainer.no_backprop_mode(): + target_quantiles = self._compute_target_values( + exp_batch, taus) + + eltwise_loss = iqn.compute_eltwise_huber_quantile_loss( + tau_hat_quantiles, target_quantiles, taus_hat) + + tau_grad = (2 * tau_quantiles[:, :-1] + - tau_hat_quantiles[:, :-1] + - tau_hat_quantiles[:, 1:]).array + proposal_loss = F.mean(F.sum(tau_grad * taus[:, :-1], axis=1)) + + self.proposal_loss_record.append(float(proposal_loss.array)) + + # Record entropy of proposals + xp = self.xp + probs = taus.array.copy() + probs[:, 1:] -= taus.array[:, :-1] + self.proposal_entropy_record.append( + float(xp.mean(xp.sum(-xp.log(probs + 1e-8), axis=1)))) + + if errors_out is not None: + del errors_out[:] + delta = F.mean(eltwise_loss, axis=(1, 2)) + errors_out.extend(cuda.to_cpu(delta.array)) + + if 'weights' in exp_batch: + return proposal_loss + iqn.compute_weighted_value_loss( + eltwise_loss, exp_batch['weights'], + batch_accumulator=self.batch_accumulator) + else: + return proposal_loss + iqn.compute_value_loss( + eltwise_loss, batch_accumulator=self.batch_accumulator) + + def _evaluate_model_and_update_recurrent_states(self, batch_obs, test): + batch_xs = self.batch_states(batch_obs, self.xp, self.phi) + if self.recurrent: + if test: + (av, _), self.test_recurrent_states = self.model( + batch_xs, recurrent_state=self.test_recurrent_states) + else: + self.train_prev_recurrent_states = self.train_recurrent_states + (av, _), self.train_recurrent_states = self.model( + batch_xs, recurrent_state=self.train_recurrent_states) + else: + av, _ = self.model(batch_xs) + return av + + def get_statistics(self): + return super().get_statistics() + [ + ('average_proposal_loss', _mean_or_nan(self.proposal_loss_record)), + ('average_proposal_entropy', _mean_or_nan( + self.proposal_entropy_record)), + ] diff --git a/tests/agents_tests/test_fqf.py b/tests/agents_tests/test_fqf.py new file mode 100644 index 000000000..fd7fcb337 --- /dev/null +++ b/tests/agents_tests/test_fqf.py @@ -0,0 +1,83 @@ +from __future__ import unicode_literals +from __future__ import print_function +from __future__ import division +from __future__ import absolute_import +from future import standard_library +from builtins import * # NOQA +standard_library.install_aliases() # NOQA + +import chainer +import chainer.functions as F +import chainer.links as L +from chainer import testing + +import basetest_dqn_like as base +from basetest_training import _TestBatchTrainingMixin +import chainerrl +from chainerrl.agents import fqf + + +@testing.parameterize(*testing.product({ + 'N': [2, 32], +})) +class TestFQFOnDiscreteABC( + _TestBatchTrainingMixin, 
base._TestDQNOnDiscreteABC): + + def make_q_func(self, env): + obs_size = env.observation_space.low.size + hidden_size = 64 + return fqf.FQQFunction( + psi=chainerrl.links.Sequence( + L.Linear(obs_size, hidden_size), + F.relu, + ), + phi=chainerrl.links.Sequence( + chainerrl.agents.iqn.CosineBasisLinear(32, hidden_size), + F.relu, + ), + f=L.Linear(hidden_size, env.action_space.n), + proposal_net=L.Linear( + hidden_size, + self.N, + initialW=chainer.initializers.LeCunNormal(1e-2), + ), + ) + + def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu): + return fqf.FQF( + q_func, opt, rbuf, gpu=gpu, gamma=0.9, explorer=explorer, + replay_start_size=100, target_update_interval=100, + ) + + +class TestFQFOnDiscretePOABC( + _TestBatchTrainingMixin, base._TestDQNOnDiscretePOABC): + + def make_q_func(self, env): + obs_size = env.observation_space.low.size + hidden_size = 64 + return fqf.StatelessRecurrentFQQFunction( + psi=chainerrl.links.StatelessRecurrentSequential( + L.Linear(obs_size, hidden_size), + F.relu, + L.NStepRNNTanh(1, hidden_size, hidden_size, 0), + ), + phi=chainerrl.links.Sequence( + chainerrl.agents.iqn.CosineBasisLinear(32, hidden_size), + F.relu, + ), + f=L.Linear(hidden_size, env.action_space.n, + initialW=chainer.initializers.LeCunNormal(1e-1)), + proposal_net=L.Linear( + hidden_size, + 17, + initialW=chainer.initializers.LeCunNormal(1e-2), + ), + ) + + def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu): + return fqf.FQF( + q_func, opt, rbuf, gpu=gpu, gamma=0.9, explorer=explorer, + replay_start_size=100, target_update_interval=100, + recurrent=True, + ) From 29347fa18918683ac2d9136ebe273e40536f6216 Mon Sep 17 00:00:00 2001 From: muupan Date: Sun, 10 Nov 2019 20:54:53 +0900 Subject: [PATCH 4/6] Add a script to train FQF on Atari --- examples/atari/reproduction/fqf/train_fqf.py | 209 +++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 examples/atari/reproduction/fqf/train_fqf.py diff --git a/examples/atari/reproduction/fqf/train_fqf.py b/examples/atari/reproduction/fqf/train_fqf.py new file mode 100644 index 000000000..56ac63ee8 --- /dev/null +++ b/examples/atari/reproduction/fqf/train_fqf.py @@ -0,0 +1,209 @@ +from __future__ import print_function +from __future__ import division +from __future__ import unicode_literals +from __future__ import absolute_import +from builtins import * # NOQA +from future import standard_library +standard_library.install_aliases() # NOQA +import argparse +import functools +import json +import os + +import chainer +import chainer.functions as F +import chainer.links as L +import numpy as np + +import chainerrl +from chainerrl import experiments +from chainerrl import explorers +from chainerrl import misc +from chainerrl import replay_buffer +from chainerrl.wrappers import atari_wrappers + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4') + parser.add_argument('--outdir', type=str, default='results', + help='Directory path to save output files.' 
+ ' If it does not exist, it will be created.') + parser.add_argument('--seed', type=int, default=0, + help='Random seed [0, 2 ** 31)') + parser.add_argument('--gpu', type=int, default=0) + parser.add_argument('--demo', action='store_true', default=False) + parser.add_argument('--load', type=str, default=None) + parser.add_argument('--final-exploration-frames', + type=int, default=10 ** 6) + parser.add_argument('--final-epsilon', type=float, default=0.01) + parser.add_argument('--eval-epsilon', type=float, default=0.001) + parser.add_argument('--steps', type=int, default=5 * 10 ** 7) + parser.add_argument('--max-frames', type=int, + default=30 * 60 * 60, # 30 minutes with 60 fps + help='Maximum number of frames for each episode.') + parser.add_argument('--replay-start-size', type=int, default=5 * 10 ** 4) + parser.add_argument('--target-update-interval', + type=int, default=10 ** 4) + parser.add_argument('--eval-interval', type=int, default=250000) + parser.add_argument('--eval-n-steps', type=int, default=125000) + parser.add_argument('--update-interval', type=int, default=4) + parser.add_argument('--batch-size', type=int, default=32) + parser.add_argument('--logging-level', type=int, default=20, + help='Logging level. 10:DEBUG, 20:INFO etc.') + parser.add_argument('--render', action='store_true', default=False, + help='Render env states in a GUI window.') + parser.add_argument('--monitor', action='store_true', default=False, + help='Monitor env. Videos and additional information' + ' are saved as output files.') + parser.add_argument('--batch-accumulator', type=str, default='mean', + choices=['mean', 'sum']) + parser.add_argument('--quantile-thresholds-N', type=int, default=32) + parser.add_argument('--n-best-episodes', type=int, default=200) + args = parser.parse_args() + + import logging + logging.basicConfig(level=args.logging_level) + + # Set a random seed used in ChainerRL. + misc.set_random_seed(args.seed, gpus=(args.gpu,)) + + # Set different random seeds for train and test envs. 
+ train_seed = args.seed + test_seed = 2 ** 31 - 1 - args.seed + + args.outdir = experiments.prepare_output_dir(args, args.outdir) + print('Output files are saved in {}'.format(args.outdir)) + + def make_env(test): + # Use different random seeds for train and test envs + env_seed = test_seed if test else train_seed + env = atari_wrappers.wrap_deepmind( + atari_wrappers.make_atari(args.env, max_frames=args.max_frames), + episode_life=not test, + clip_rewards=not test) + env.seed(int(env_seed)) + if test: + # Randomize actions like epsilon-greedy in evaluation as well + env = chainerrl.wrappers.RandomizeAction(env, args.eval_epsilon) + if args.monitor: + env = chainerrl.wrappers.Monitor( + env, args.outdir, + mode='evaluation' if test else 'training') + if args.render: + env = chainerrl.wrappers.Render(env) + return env + + env = make_env(test=False) + eval_env = make_env(test=True) + n_actions = env.action_space.n + + q_func = chainerrl.agents.fqf.FQQFunction( + psi=chainerrl.links.Sequence( + L.Convolution2D(None, 32, 8, stride=4), + F.relu, + L.Convolution2D(None, 64, 4, stride=2), + F.relu, + L.Convolution2D(None, 64, 3, stride=1), + F.relu, + functools.partial(F.reshape, shape=(-1, 3136)), + ), + phi=chainerrl.links.Sequence( + chainerrl.agents.iqn.CosineBasisLinear(64, 3136), + F.relu, + ), + f=chainerrl.links.Sequence( + L.Linear(None, 512), + F.relu, + L.Linear(None, n_actions), + ), + proposal_net=L.Linear( + None, + args.quantile_thresholds_N, + initialW=chainer.initializers.LeCunNormal(1e-2), + ), + ) + + # Draw the computational graph and save it in the output directory. + fake_obss = np.zeros((4, 84, 84), dtype=np.float32)[None] + chainerrl.misc.draw_computational_graph( + [q_func(fake_obss)[0]], + os.path.join(args.outdir, 'model')) + + # Use the same hyper parameters as https://arxiv.org/abs/1710.10044 + opt = chainer.optimizers.Adam(5e-5, eps=1e-2 / args.batch_size) + opt.setup(q_func) + + # Lower the learning rate for the proposal network + q_func.proposal_net.W.update_rule.alpha = 1e-5 + q_func.proposal_net.b.update_rule.alpha = 1e-5 + + rbuf = replay_buffer.ReplayBuffer(10 ** 6) + + explorer = explorers.LinearDecayEpsilonGreedy( + 1.0, args.final_epsilon, + args.final_exploration_frames, + lambda: np.random.randint(n_actions)) + + def phi(x): + # Feature extractor + return np.asarray(x, dtype=np.float32) / 255 + + agent = chainerrl.agents.FQF( + q_func, opt, rbuf, gpu=args.gpu, gamma=0.99, + explorer=explorer, replay_start_size=args.replay_start_size, + target_update_interval=args.target_update_interval, + update_interval=args.update_interval, + batch_accumulator=args.batch_accumulator, + phi=phi, + ) + + if args.load: + agent.load(args.load) + + if args.demo: + eval_stats = experiments.eval_performance( + env=eval_env, + agent=agent, + n_steps=args.eval_n_steps, + n_episodes=None, + ) + print('n_steps: {} mean: {} median: {} stdev {}'.format( + args.eval_n_steps, eval_stats['mean'], eval_stats['median'], + eval_stats['stdev'])) + else: + experiments.train_agent_with_evaluation( + agent=agent, + env=env, + steps=args.steps, + eval_n_steps=args.eval_n_steps, + eval_n_episodes=None, + eval_interval=args.eval_interval, + outdir=args.outdir, + save_best_so_far_agent=True, + eval_env=eval_env, + ) + + dir_of_best_network = os.path.join(args.outdir, "best") + agent.load(dir_of_best_network) + + # run 200 evaluation episodes, each capped at 30 mins of play + stats = experiments.evaluator.eval_performance( + env=eval_env, + agent=agent, + n_steps=None, + 
n_episodes=args.n_best_episodes, + max_episode_len=args.max_frames / 4, + logger=None) + with open(os.path.join(args.outdir, 'bestscores.json'), 'w') as f: + # temporary hack to handle python 2/3 support issues. + # json dumps does not support non-string literal dict keys + json_stats = json.dumps(stats) + print(str(json_stats), file=f) + print("The results of the best scoring network:") + for stat in stats: + print(str(stat) + ":" + str(stats[stat])) + + +if __name__ == '__main__': + main() From 2bd4a79c927810df991f357d255b44a6e9714a71 Mon Sep 17 00:00:00 2001 From: muupan Date: Mon, 11 Nov 2019 15:16:56 +0900 Subject: [PATCH 5/6] Add assert --- chainerrl/action_value.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/chainerrl/action_value.py b/chainerrl/action_value.py index 6db752c51..5d8f188b5 100644 --- a/chainerrl/action_value.py +++ b/chainerrl/action_value.py @@ -209,6 +209,9 @@ def __init__( assert quantiles.ndim == 3 self.quantiles = quantiles self.weights = weights + if weights is not None: + assert weights.ndim == 2 + assert weights.shape == quantiles.shape[:2] self.xp = cuda.get_array_module(quantiles.array) self.n_actions = quantiles.shape[2] self.q_values_formatter = q_values_formatter From f9b24f0bd68449963408cd694bd81c1b221a0c96 Mon Sep 17 00:00:00 2001 From: muupan Date: Mon, 11 Nov 2019 15:17:59 +0900 Subject: [PATCH 6/6] Fix entropy computation and record tau grad norm --- chainerrl/agents/fqf.py | 49 +++++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/chainerrl/agents/fqf.py b/chainerrl/agents/fqf.py index e8c791d32..99eb27989 100644 --- a/chainerrl/agents/fqf.py +++ b/chainerrl/agents/fqf.py @@ -49,7 +49,7 @@ def _evaluate_psi_x_with_quantile_thresholds( def _compute_taus_hat_and_weights(taus): batch_size, _ = taus.shape xp = chainer.cuda.get_array_module(taus) - assert xp.allclose(taus[:, -1], xp.ones(batch_size)) + _assert_taus(taus) # shifted_tau: [0, tau_0, tau_1, ..., tau_{N-2}] shifted_tau = xp.concatenate( [xp.zeros((batch_size, 1), dtype=taus.dtype), taus[:, :-1]], @@ -111,6 +111,7 @@ def __call__(self, x, taus=None, with_tau_quantiles=False): # Make sure errors of the proposal net do not backprop to psi taus = F.cumsum( F.softmax(self.proposal_net(psi_x.array), axis=1), axis=1) + _assert_taus(taus) # Quantiles based on tau hat, used to compute Q-values taus_hat, weights = _compute_taus_hat_and_weights( @@ -119,7 +120,9 @@ def __call__(self, x, taus=None, with_tau_quantiles=False): psi_x, self.phi, self.f, taus_hat, weights=weights) if with_tau_quantiles: - # Quantiles based on tau, used to update the proposal net + # Quantiles based on tau, used to update the proposal net. + # Since we don't compute Q-values based on tau, we don't need to + # specify weights here. tau_av = _evaluate_psi_x_with_quantile_thresholds( psi_x, self.phi, self.f, _unwrap_variable(taus)) return tau_hat_av, taus_hat, tau_av, taus @@ -178,6 +181,7 @@ def n_step_forward( # Make sure errors of the proposal net do not backprop to psi taus = F.cumsum( F.softmax(self.proposal_net(psi_x.array), axis=1), axis=1) + _assert_taus(taus) # Quantiles based on tau hat, used to compute Q-values taus_hat, weights = _compute_taus_hat_and_weights( @@ -187,6 +191,8 @@ def n_step_forward( if with_tau_quantiles: # Quantiles based on tau, used to update the proposal net + # Since we don't compute Q-values based on tau, we don't need to + # specify weights here. 
tau_av = _evaluate_psi_x_with_quantile_thresholds( psi_x, self.phi, self.f, _unwrap_variable(taus)) return (tau_hat_av, taus_hat, tau_av, taus), (recurrent_state,) @@ -201,6 +207,31 @@ def _unwrap_variable(x): return x +def _assert_taus(taus): + xp = chainer.cuda.get_array_module(taus) + taus = _unwrap_variable(taus) + # all the elements must be less than or equal to 1 + assert xp.all(taus <= 1 + 1e-6), taus + # the last element must be 1 + assert xp.allclose(taus[:, -1], xp.ones(len(taus))), taus + + +def _restore_probs_from_taus(taus): + _assert_taus(taus) + taus = _unwrap_variable(taus) + xp = chainer.cuda.get_array_module(taus) + probs = taus.copy() + probs[:, 1:] -= taus[:, :-1] + assert xp.allclose(probs.sum(axis=1), xp.ones(len(taus))) + return probs + + +def _mean_entropy(probs): + assert probs.ndim == 2 + xp = chainer.cuda.get_array_module(probs) + return -float(xp.mean(xp.sum(probs * xp.log(probs + 1e-8), axis=1))) + + class FQF(dqn.DQN): """Fully-parameterized Quantile Function (FQF) algorithm. @@ -215,7 +246,7 @@ class FQF(dqn.DQN): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.proposal_loss_record = collections.deque(maxlen=100) + self.tau_grad_norm_record = collections.deque(maxlen=100) self.proposal_entropy_record = collections.deque(maxlen=100) def _compute_target_values(self, exp_batch, taus): @@ -304,16 +335,16 @@ def _compute_loss(self, exp_batch, errors_out=None): tau_grad = (2 * tau_quantiles[:, :-1] - tau_hat_quantiles[:, :-1] - tau_hat_quantiles[:, 1:]).array + xp = chainer.cuda.get_array_module(tau_grad) proposal_loss = F.mean(F.sum(tau_grad * taus[:, :-1], axis=1)) - self.proposal_loss_record.append(float(proposal_loss.array)) + # Record norm of \partial W_1 / \partial \tau + tau_grad_norm = xp.mean(xp.linalg.norm(tau_grad, axis=1)) + self.tau_grad_norm_record.append(float(tau_grad_norm)) # Record entropy of proposals - xp = self.xp - probs = taus.array.copy() - probs[:, 1:] -= taus.array[:, :-1] self.proposal_entropy_record.append( - float(xp.mean(xp.sum(-xp.log(probs + 1e-8), axis=1)))) + _mean_entropy(_restore_probs_from_taus(taus))) if errors_out is not None: del errors_out[:] @@ -344,7 +375,7 @@ def _evaluate_model_and_update_recurrent_states(self, batch_obs, test): def get_statistics(self): return super().get_statistics() + [ - ('average_proposal_loss', _mean_or_nan(self.proposal_loss_record)), + ('average_tau_grad_norm', _mean_or_nan(self.tau_grad_norm_record)), ('average_proposal_entropy', _mean_or_nan( self.proposal_entropy_record)), ]
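
The weights argument added to QuantileDiscreteActionValue in PATCH 1/6 turns the Q-value into the weighted sum over quantiles, Q(s, a) = sum_i w_i * z_i(s, a); with uniform weights this reduces to the plain mean that was used before the change. Below is a quick NumPy check of that identity, using the shapes from the docstring, (batch_size, n_taus, n_actions), and made-up quantile values (an illustration only, not the patched Chainer code):

import numpy as np

rng = np.random.RandomState(1)
quantiles = rng.randn(2, 4, 3).astype(np.float32)   # (batch_size, n_taus, n_actions)
uniform = np.full((2, 4), 0.25, dtype=np.float32)   # (batch_size, n_taus), sums to 1

q_weighted = (uniform[..., None] * quantiles).sum(axis=1)  # weighted sum over taus
q_mean = quantiles.mean(axis=1)                            # previous behaviour
assert np.allclose(q_weighted, q_mean)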
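
The quantile-fraction bookkeeping used throughout PATCH 3/6 can be checked in isolation. The sketch below mirrors _compute_taus_hat_and_weights in plain NumPy (a self-contained illustration, not the patched code): taus is a cumulative sum of a softmax, so it is increasing along axis 1 and its last column is 1; the returned weights are the per-interval probabilities tau_i - tau_{i-1}, and taus_hat are the interval midpoints at which the quantile head is evaluated.

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def compute_taus_hat_and_weights(taus):
    # taus: (batch_size, n_taus), increasing along axis 1, last column == 1
    batch_size = taus.shape[0]
    shifted = np.concatenate(
        [np.zeros((batch_size, 1), dtype=taus.dtype), taus[:, :-1]], axis=1)
    weights = taus - shifted          # interval widths tau_i - tau_{i-1}
    taus_hat = (taus + shifted) / 2   # interval midpoints tau_hat_i
    return taus_hat, weights

rng = np.random.RandomState(0)
logits = rng.randn(2, 4).astype(np.float32)        # stand-in for proposal-net output
taus = np.cumsum(softmax(logits, axis=1), axis=1)  # (2, 4), last column is ~1
taus_hat, weights = compute_taus_hat_and_weights(taus)
assert np.allclose(taus[:, -1], 1.0)
assert np.allclose(weights.sum(axis=1), 1.0)       # weights form a distribution
assert np.all((0 < taus_hat) & (taus_hat < 1))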
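
The proposal-network update in FQF._compute_loss follows the analytic derivative of the 1-Wasserstein objective from the FQF paper: for each adjustable fraction tau_i (i = 0..N-2; tau_{N-1} is fixed at 1), dW1/dtau_i = 2*Z(tau_i) - Z(tau_hat_i) - Z(tau_hat_{i+1}). Because tau_grad is taken from detached quantile values (.array), only the proposal network receives this gradient through taus, while the quantile head is trained solely by the quantile Huber loss. The sketch below redoes that arithmetic, together with the entropy bookkeeping from PATCH 6/6, in NumPy with made-up numbers (illustrative only):

import numpy as np

# One batch element, N = 4 fractions; values are made up for illustration.
taus = np.array([[0.2, 0.5, 0.8, 1.0]], dtype=np.float32)         # cumulative fractions
z_tau = np.array([[-0.5, 0.1, 0.7, 1.5]], dtype=np.float32)       # Z(tau_i)
z_tau_hat = np.array([[-0.8, -0.1, 0.4, 1.1]], dtype=np.float32)  # Z(tau_hat_i)

# dW1/dtau_i = 2 Z(tau_i) - Z(tau_hat_i) - Z(tau_hat_{i+1}), for i = 0..N-2
tau_grad = 2 * z_tau[:, :-1] - z_tau_hat[:, :-1] - z_tau_hat[:, 1:]

# Surrogate loss: tau_grad is treated as a constant, so the gradient of this
# expression with respect to tau_i is exactly tau_grad_i.
proposal_loss = (tau_grad * taus[:, :-1]).sum(axis=1).mean()

# Proposal entropy as recorded after the PATCH 6/6 fix: recover the
# probabilities from the cumulative fractions, then compute -sum p log p.
probs = taus.copy()
probs[:, 1:] -= taus[:, :-1]
entropy = float(-(probs * np.log(probs + 1e-8)).sum(axis=1).mean())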