-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAgent.py
More file actions
119 lines (104 loc) · 6.09 KB
/
Agent.py
File metadata and controls
119 lines (104 loc) · 6.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
from data import ReplayBuffer
import torch.nn as nn
from torch.distributions import Uniform
import numpy as np
class Agent:
    """Conservative Q-Learning (CQL) agent built on a Soft Actor-Critic backbone.

    Trains an actor and twin critics from an offline demonstration buffer.
    The SAC entropy temperature is tuned automatically through ``log_alpha``,
    while ``self.alpha`` is a separate, fixed weight on the CQL conservatism
    penalty.
    """

    # Number of actions sampled per state for the bootstrapped target and for
    # the CQL logsumexp penalty.
    N_ACTION_SAMPLES = 10

    def __init__(self, network, gamma, demo_buffer, actor_lr, critic_lr, n_steps, target_entropy):
        """
        Args:
            network: project model bundle exposing ``actor``, ``critic1/2``,
                ``target1/2``, ``act``, ``evaluate``, ``sample_n`` and
                ``update_target``.
            gamma: discount factor for TD targets.
            demo_buffer: offline transition data wrapped by ``ReplayBuffer``.
            actor_lr: initial Adam learning rate for the actor (also used for
                the temperature optimizer).
            critic_lr: initial Adam learning rate for both critics.
            n_steps: horizon over which the caller decays the learning rates.
            target_entropy: entropy target for automatic temperature tuning.
        """
        self.network = network
        self.gamma = gamma
        self.alpha = 0.1  # fixed CQL penalty weight (NOT the SAC temperature)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.buffer = ReplayBuffer(demo_buffer, gamma, self.device)
        self.mse = nn.MSELoss()
        self.actor_optimizer = torch.optim.Adam(self.network.actor.parameters(), actor_lr)
        self.critic1_optimizer = torch.optim.Adam(self.network.critic1.parameters(), critic_lr)
        self.critic2_optimizer = torch.optim.Adam(self.network.critic2.parameters(), critic_lr)
        self.initial_lr_actor = actor_lr
        self.initial_lr_critic = critic_lr
        self.lr_actor = actor_lr
        self.lr_critic = critic_lr
        # BUG FIX: the original `torch.zeros(1, requires_grad=True).to(device)`
        # produces a NON-LEAF tensor — `.to()` is a differentiable op — so its
        # `.grad` is never populated and Adam silently never updates the
        # temperature. Creating the leaf directly on the target device fixes it.
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=actor_lr)
        self.target_entropy = target_entropy
        # Per-step linear decay increments toward 1e-7 (applied by the caller).
        self.actor_lr_decay = (self.initial_lr_actor - 1e-7) / n_steps
        self.critic_lr_decay = (self.initial_lr_critic - 1e-7) / n_steps

    def update(self, batch_size, n_epochs):
        """Run ``n_epochs`` gradient updates on minibatches from the buffer.

        Each epoch performs: temperature update, actor (SAC) loss, twin-critic
        TD loss with a CQL logsumexp penalty, optimizer steps, and a target
        network update.

        Returns:
            Six per-epoch lists: policy losses, critic-1 losses, critic-2
            losses, policy entropies, temperature losses, temperatures.
        """
        policy_losses = []
        q1_val_loss = []
        q2_val_loss = []
        entropies = []
        alpha_losses = []
        alphas = []
        n_samples = self.N_ACTION_SAMPLES
        for e in range(n_epochs):
            # `states` / `next_states` appear to be lists of tensors (one per
            # observation modality) — TODO confirm against ReplayBuffer.sample.
            states, actions, next_states, rewards, dones = self.buffer.sample(batch_size)
            new_acts, log_probs, entropy = self.network.act(states)
            # --- automatic temperature (SAC alpha) update ---
            alpha_loss = -(self.log_alpha * (log_probs.detach() + self.target_entropy)).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
            # --- SAC actor loss: alpha * log_pi - Q(s, pi(s)) ---
            q_new_acts = self.network.evaluate(states, new_acts)
            policy_loss = (alpha * log_probs - q_new_acts).mean()
            # --- TD targets for the twin critics ---
            q1_vals = self.network.critic1(states + [actions])
            q2_vals = self.network.critic2(states + [actions])
            next_actions, _ = self.network.sample_n(next_states, n_samples)
            # Tile each state tensor n_samples times so it pairs with the
            # sampled actions. BUG FIX: the original hard-coded reshape(160,
            # ...) assumed batch_size == 16; reshape(-1, ...) is identical for
            # that case and works for any batch size.
            new_next_states = [torch.stack([next_states[i]] * n_samples).reshape(-1, *next_states[i].shape[1:])
                               for i in range(len(next_states))]
            target_q1 = self.network.target1(new_next_states + [next_actions])
            target_q1 = target_q1.reshape(next_actions.shape).max(0)[0]
            target_q2 = self.network.target2(new_next_states + [next_actions])
            target_q2 = target_q2.reshape(next_actions.shape).max(0)[0]
            # Max over the sampled actions, min over the twins (clipped double-Q).
            target_q = torch.min(target_q1, target_q2)
            done_mask = torch.where(dones == 1.0, 0, 1)  # no bootstrap past terminals
            td_target = rewards.unsqueeze(-1) + self.gamma * target_q.detach() * done_mask.unsqueeze(-1)
            td1_error = self.mse(q1_vals.squeeze(), td_target)
            td2_error = self.mse(q2_vals.squeeze(), td_target)
            # --- CQL conservatism penalty: logsumexp over Q on random,
            # current-policy and next-state-policy actions, minus Q on data ---
            curr_actions, curr_probs = self.network.sample_n(states, n_samples)
            next_actions, next_probs = self.network.sample_n(next_states, n_samples)
            # BUG FIX: Uniform.rsample always allocates on CPU; move the random
            # actions to the device the policy actions live on so the critic
            # inputs are not mixed-device under CUDA (no-op on CPU).
            rand_actions = Uniform(low=-1, high=1).rsample(sample_shape=next_actions.shape).to(next_actions.device)
            # Log-density of Uniform(-1, 1)^act_dim, i.e. act_dim * log(1/2).
            rand_density = np.log(0.5 ** curr_actions.shape[-1])
            new_states = [torch.stack([states[i]] * n_samples).reshape(-1, *states[i].shape[1:])
                          for i in range(len(states))]
            q1_rand = self.network.critic1(new_states + [rand_actions])
            q2_rand = self.network.critic2(new_states + [rand_actions])
            q1_rand = q1_rand.reshape(rand_actions.shape)
            q2_rand = q2_rand.reshape(rand_actions.shape)
            q1_curr_actions = self.network.critic1(new_states + [curr_actions])
            q2_curr_actions = self.network.critic2(new_states + [curr_actions])
            q1_curr_actions = q1_curr_actions.reshape(curr_actions.shape)
            q2_curr_actions = q2_curr_actions.reshape(curr_actions.shape)
            q1_next_actions = self.network.critic1(new_states + [next_actions])
            q2_next_actions = self.network.critic2(new_states + [next_actions])
            q1_next_actions = q1_next_actions.reshape(curr_actions.shape)
            q2_next_actions = q2_next_actions.reshape(curr_actions.shape)
            # NOTE(review): the three Q batches are concatenated along dim 1 but
            # logsumexp is reduced along dim 0 — reference CQL reduces over the
            # concatenation axis. Verify against sample_n's output shape before
            # changing; left as-is to preserve trained behavior.
            cat_q1 = torch.cat([q1_rand - rand_density, q1_next_actions - next_probs.detach(), q1_curr_actions - curr_probs.detach()], 1)
            cat_q2 = torch.cat([q2_rand - rand_density, q2_next_actions - next_probs.detach(), q2_curr_actions - curr_probs.detach()], 1)
            min_q1_loss = torch.logsumexp(cat_q1, dim=0).mean() * self.alpha
            min_q2_loss = torch.logsumexp(cat_q2, dim=0).mean() * self.alpha
            # Subtract the dataset Q-values so in-distribution actions are not penalized.
            min_q1_loss = min_q1_loss - q1_vals.mean() * self.alpha
            min_q2_loss = min_q2_loss - q2_vals.mean() * self.alpha
            q1_loss = td1_error + min_q1_loss
            q2_loss = td2_error + min_q2_loss
            # retain_graph: the two critic losses and the actor loss share
            # upstream graph nodes, so intermediate buffers must survive the
            # earlier backward passes.
            self.critic1_optimizer.zero_grad()
            q1_loss.backward(retain_graph=True)
            self.critic1_optimizer.step()
            self.critic2_optimizer.zero_grad()
            q2_loss.backward(retain_graph=True)
            self.critic2_optimizer.step()
            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            self.actor_optimizer.step()
            self.network.update_target(self.network.critic1, self.network.target1)
            self.network.update_target(self.network.critic2, self.network.target2)
            policy_losses.append(policy_loss.detach().item())
            q1_val_loss.append(q1_loss.detach().item())
            q2_val_loss.append(q2_loss.detach().item())
            entropies.append(entropy.detach().mean().item())
            alpha_losses.append(alpha_loss.detach().item())
            alphas.append(alpha.detach().item())
        return policy_losses, q1_val_loss, q2_val_loss, entropies, alpha_losses, alphas