-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathagent.py
More file actions
114 lines (96 loc) · 5.06 KB
/
agent.py
File metadata and controls
114 lines (96 loc) · 5.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#######################################################################
# Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) #
# Permission given to modify the code as long as you keep this #
# declaration at the top #
#######################################################################
#######################################################################
# Modified in 2018 by Yongjin Jung (ainvyu@gmail.com)                  #
# Only import and modify necessary parts for learning #
#######################################################################
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from util import *
def to_np(actions):
    """Convert a torch tensor of actions to a plain NumPy array.

    Detaches from the autograd graph and moves the data to host memory
    first, so gradient-tracking and GPU-resident tensors are both handled.
    """
    return actions.detach().cpu().numpy()
def random_sample(indices, batch_size):
    """Yield shuffled mini-batches drawn from `indices`.

    Shuffles the index set once, then yields consecutive batches of
    `batch_size`; if the index count is not a multiple of `batch_size`,
    one final smaller batch with the leftover indices is yielded.

    Parameters
    ----------
    indices : array-like of int
        Index pool to sample from (e.g. ``np.arange(n)``).
    batch_size : int
        Number of indices per full batch.

    Yields
    ------
    np.ndarray
        1-D arrays of indices; all of length `batch_size` except possibly
        the last one.
    """
    # np.random.permutation already returns an ndarray; the original's
    # extra np.asarray wrapper was redundant.
    shuffled = np.random.permutation(indices)
    full = len(shuffled) // batch_size * batch_size
    for batch in shuffled[:full].reshape(-1, batch_size):
        yield batch
    remainder = len(shuffled) % batch_size
    if remainder:
        # Final short batch so every index is used exactly once.
        yield shuffled[-remainder:]
class PPOAgent:
    """Proximal Policy Optimization agent.

    Alternates between collecting a fixed-length rollout from a vectorized
    environment (``config.num_workers`` parallel workers) and running
    several epochs of clipped-surrogate PPO updates on that rollout.
    """
    def __init__(self, config, env, brain_name, task_fn, network, opt, device):
        # Hyperparameter container (rollout_length, discount, gae_tau,
        # ppo_ratio_clip, entropy_weight, gradient_clip, normalizers, ...).
        self.config = config
        self.env = env
        # NOTE(review): despite the name, task_fn is stored and later used as
        # a task OBJECT with a .step(actions) method — confirm at call sites.
        self.task = task_fn
        self.network = network
        self.opt = opt
        # Total environment steps taken across all workers.
        self.total_steps = 0
        # Per-worker running reward accumulators for in-progress episodes.
        self.online_rewards = np.zeros(config.num_workers)
        # Returns of completed episodes, appended as workers hit terminals.
        self.episode_rewards = []
        # env.reset(...)[brain_name].vector_observations looks like the Unity
        # ML-Agents API — presumably one observation row per worker; verify.
        self.states = env.reset(train_mode=True)[brain_name].vector_observations
        self.device = device
    def step(self):
        """Collect one rollout, compute returns/advantages, run PPO epochs."""
        config = self.config
        rollout = []
        states = self.states
        # ---- Rollout collection -------------------------------------------
        for _ in range(config.rollout_length):
            actions, log_probs, _, values = self.network(states)
            next_states, rewards, terminals, _ = self.task.step(to_np(actions))
            self.online_rewards += rewards
            rewards = config.reward_normalizer(rewards)
            for i, terminal in enumerate(terminals):
                if terminals[i]:
                    # Worker i finished an episode: record and reset its return.
                    self.episode_rewards.append(self.online_rewards[i])
                    self.online_rewards[i] = 0
            next_states = config.state_normalizer(next_states)
            # `1 - terminals` is the continuation mask used below to stop
            # bootstrapping across episode boundaries.
            rollout.append([states, values.detach(), actions.detach(), log_probs.detach(), rewards, 1 - terminals])
            states = next_states
        self.states = states
        # Bootstrap value for the state following the last rollout step.
        pending_value = self.network(states)[-1]
        rollout.append([states, pending_value, None, None, None, None])
        # ---- Return / advantage computation, backwards through time -------
        processed_rollout = [None] * (len(rollout) - 1)
        advantages = tensor(np.zeros((config.num_workers, 1)))
        returns = pending_value.detach()
        for i in reversed(range(len(rollout) - 1)):
            states, value, actions, log_probs, rewards, terminals = rollout[i]
            # `terminals` here is the 1-done continuation mask stored above.
            terminals = tensor(terminals).unsqueeze(1)
            rewards = tensor(rewards).unsqueeze(1)
            actions = tensor(actions)
            states = tensor(states)
            next_value = rollout[i + 1][1]
            # Discounted return, masked so it resets at episode boundaries.
            returns = rewards + config.discount * terminals * returns
            if not config.use_gae:
                advantages = returns - value.detach()
            else:
                # Generalized Advantage Estimation:
                # A_t = delta_t + gamma * lambda * mask * A_{t+1}
                td_error = rewards + config.discount * terminals * next_value.detach() - value.detach()
                advantages = advantages * config.gae_tau * config.discount * terminals + td_error
            processed_rollout[i] = [states, actions, log_probs, returns, advantages]
        # Flatten the per-step entries into single tensors over all
        # (worker, timestep) transitions.
        states, actions, log_probs_old, returns, advantages = map(lambda x: torch.cat(x, dim=0), zip(*processed_rollout))
        # NOTE(review): advantages.std() can be zero for a degenerate batch,
        # yielding NaNs — many implementations add a small epsilon here.
        advantages = (advantages - advantages.mean()) / advantages.std()
        # ---- PPO optimization epochs over shuffled mini-batches -----------
        for _ in range(config.optimization_epochs):
            sampler = random_sample(np.arange(states.size(0)), config.mini_batch_size)
            for batch_indices in sampler:
                batch_indices = tensor(batch_indices).long()
                sampled_states = states[batch_indices]
                sampled_actions = actions[batch_indices]
                sampled_log_probs_old = log_probs_old[batch_indices]
                sampled_returns = returns[batch_indices]
                sampled_advantages = advantages[batch_indices]
                # Re-evaluate the CURRENT policy on the sampled transitions.
                _, log_probs, entropy_loss, values = self.network(sampled_states, sampled_actions)
                # Importance ratio pi_new / pi_old, computed in log space.
                ratio = (log_probs - sampled_log_probs_old).exp()
                obj = ratio * sampled_advantages
                obj_clipped = ratio.clamp(1.0 - self.config.ppo_ratio_clip,
                                          1.0 + self.config.ppo_ratio_clip) * sampled_advantages
                # Clipped surrogate objective minus an entropy bonus.
                policy_loss = -torch.min(obj, obj_clipped).mean(0) - config.entropy_weight * entropy_loss.mean()
                # Mean-squared error against the computed returns.
                value_loss = 0.5 * (sampled_returns - values).pow(2).mean()
                self.opt.zero_grad()
                (policy_loss + value_loss).backward()
                nn.utils.clip_grad_norm_(self.network.parameters(), config.gradient_clip)
                self.opt.step()
        steps = config.rollout_length * config.num_workers
        self.total_steps += steps