-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathgae.py
More file actions
28 lines (27 loc) · 1.2 KB
/
gae.py
File metadata and controls
28 lines (27 loc) · 1.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
def compute_advantages(rewards, values, dones, next_value, next_done, gamma, gae_lambda, is_gae, num_steps, device):
    """Compute per-step advantages and returns for a rollout.

    The rollout is walked backwards from step ``num_steps - 1`` to step 0.
    ``rewards``, ``values`` and ``dones`` are indexed by timestep along their
    first axis; ``dones[t + 1]`` gates whether step ``t`` may bootstrap from
    step ``t + 1``. For the last step the bootstrap comes from ``next_value``
    and ``next_done`` instead.

    Args:
        rewards: Tensor of rewards, indexed by timestep on axis 0.
        values: Value estimates, indexed like ``rewards``.
        dones: Done flags (1.0 = done), indexed like ``rewards``.
        next_value: Value estimate for the state after the final step.
            NOTE(review): must broadcast against ``rewards[t]`` -- confirm
            the exact shape against the caller.
        next_done: Done flag for the state after the final step.
        gamma: Discount factor.
        gae_lambda: Lambda weight of the GAE recursion (used only when
            ``is_gae`` is true).
        is_gae: If true, use the GAE recursion
            ``A_t = delta_t + gamma * gae_lambda * nonterminal * A_{t+1}``
            and set ``returns = advantages + values``. Otherwise compute
            discounted returns and set ``advantages = returns - values``.
        num_steps: Number of timesteps to process.
        device: Device on which the result buffer is allocated.

    Returns:
        A tuple ``(advantages, returns)``.

    Note:
        This function does not disable autograd itself; wrap the call in
        ``torch.no_grad()`` if no graph should be recorded.
    """
    last_step = num_steps - 1

    if is_gae:
        # Allocate directly on the target device instead of allocating on
        # rewards' device and copying afterwards.
        advantages = torch.zeros_like(rewards, device=device)
        # Loop-invariant factor of the recursion; hoisting keeps the original
        # left-to-right evaluation order, so results are unchanged.
        discount = gamma * gae_lambda
        lastgaelam = 0
        for t in reversed(range(num_steps)):
            if t == last_step:
                # Final step bootstraps from the state after the rollout.
                nextnonterminal = 1.0 - next_done
                nextvalues = next_value
            else:
                nextnonterminal = 1.0 - dones[t + 1]
                nextvalues = values[t + 1]
            # One-step TD error; the bootstrap term is masked out when the
            # following state is flagged as done.
            delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
            lastgaelam = delta + discount * nextnonterminal * lastgaelam
            advantages[t] = lastgaelam
        returns = advantages + values
    else:
        returns = torch.zeros_like(rewards, device=device)
        for t in reversed(range(num_steps)):
            if t == last_step:
                nextnonterminal = 1.0 - next_done
                next_return = next_value
            else:
                nextnonterminal = 1.0 - dones[t + 1]
                next_return = returns[t + 1]
            # Discounted return, cut off when the following state is done.
            returns[t] = rewards[t] + gamma * nextnonterminal * next_return
        advantages = returns - values

    return advantages, returns