04_dqn.py (31 changes: 9 additions & 22 deletions)
@@ -36,8 +36,6 @@ class DQN:
def __init__(self, dim_state=None, num_action=None, discount=0.9):
self.discount = discount
self.Q = QNet(dim_state, num_action)
self.target_Q = QNet(dim_state, num_action)
self.target_Q.load_state_dict(self.Q.state_dict())

def get_action(self, state):
qvals = self.Q(state)
@@ -46,21 +44,13 @@ def get_action(self, state):
def compute_loss(self, s_batch, a_batch, r_batch, d_batch, next_s_batch):
# Compute the Q values for the (s_batch, a_batch) pairs.
qvals = self.Q(s_batch).gather(1, a_batch.unsqueeze(1)).squeeze()
# Use the target Q network to compute the values of next_s_batch.
next_qvals, _ = self.target_Q(next_s_batch).detach().max(dim=1)
# Use the original Q network to compute the values of next_s_batch.
next_qvals, _ = self.Q(next_s_batch).detach().max(dim=1)
# Compute the loss with MSE.
loss = F.mse_loss(r_batch + self.discount * next_qvals * (1 - d_batch), qvals)
return loss


def soft_update(target, source, tau=0.01):
"""
update target by target = tau * source + (1 - tau) * target.
"""
for target_param, param in zip(target.parameters(), source.parameters()):
target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
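The two removals above go together: once the target network is gone there is nothing for soft_update to track, and the bootstrap value in compute_loss comes from the same network that is being trained. A minimal sketch of the resulting loss, reusing the names from the hunks above (a simplification for illustration, not the file's exact code):

# The online network both selects and evaluates the next action; detach()
# still keeps the bootstrap value out of the gradient.
next_qvals, _ = self.Q(next_s_batch).detach().max(dim=1)
td_target = r_batch + self.discount * next_qvals * (1 - d_batch)
loss = F.mse_loss(td_target, qvals)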


@dataclass
class ReplayBuffer:
maxsize: int
@@ -126,7 +116,7 @@ def train(args, env, agent):
if np.random.rand() < epsilon or i < args.warmup_steps:
action = env.action_space.sample()
else:
action = agent.get_action(torch.from_numpy(state))
action = agent.get_action(torch.from_numpy(state).to(args.device))
action = action.item()
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
@@ -155,11 +145,11 @@ def train(args, env, agent):

if i > args.warmup_steps:
bs, ba, br, bd, bns = replay_buffer.sample(n=args.batch_size)
bs = torch.tensor(bs, dtype=torch.float32)
ba = torch.tensor(ba, dtype=torch.long)
br = torch.tensor(br, dtype=torch.float32)
bd = torch.tensor(bd, dtype=torch.float32)
bns = torch.tensor(bns, dtype=torch.float32)
bs = torch.tensor(bs, dtype=torch.float32).to(args.device)
ba = torch.tensor(ba, dtype=torch.long).to(args.device)
br = torch.tensor(br, dtype=torch.float32).to(args.device)
bd = torch.tensor(bd, dtype=torch.float32).to(args.device)
bns = torch.tensor(bns, dtype=torch.float32).to(args.device)

loss = agent.compute_loss(bs, ba, br, bd, bns)
loss.backward()
@@ -168,8 +158,6 @@ def train(args, env, agent):

log["loss"].append(loss.item())

soft_update(agent.target_Q, agent.Q)

# 3. Plot the results.
plt.plot(log["loss"])
plt.yscale("log")
@@ -191,7 +179,7 @@ def eval(args, env, agent):
state, _ = env.reset()
for i in range(5000):
episode_length += 1
action = agent.get_action(torch.from_numpy(state)).item()
action = agent.get_action(torch.from_numpy(state).to(args.device)).item()
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
env.render()
@@ -229,7 +217,6 @@ def main():
set_seed(args)
agent = DQN(dim_state=args.dim_state, num_action=args.num_action, discount=args.discount)
agent.Q.to(args.device)
agent.target_Q.to(args.device)

if args.do_train:
train(args, env, agent)
06_doubledqn.py (51 changes: 27 additions & 24 deletions)
@@ -32,19 +32,21 @@ def forward(self, obs):
class DoubleDQN:
def __init__(self, dim_obs=None, num_act=None, discount=0.9):
self.discount = discount
self.model = QNet(dim_obs, num_act)
self.target_model = QNet(dim_obs, num_act)
self.target_model.load_state_dict(self.model.state_dict())
self.Q = QNet(dim_obs, num_act)
self.target_Q = QNet(dim_obs, num_act)
self.target_Q.load_state_dict(self.Q.state_dict())

def get_action(self, obs):
qvals = self.model(obs)
qvals = self.Q(obs)
return qvals.argmax()

def compute_loss(self, s_batch, a_batch, r_batch, d_batch, next_s_batch):
# Compute current Q value based on current states and actions.
qvals = self.model(s_batch).gather(1, a_batch.unsqueeze(1)).squeeze()
qvals = self.Q(s_batch).gather(1, a_batch.unsqueeze(1)).squeeze()
# Select the next action with the online Q network.
next_a_batch = self.Q(next_s_batch).argmax(dim=1)
# The next-state value is kept out of gradient computation to avoid divergence.
next_qvals, _ = self.target_model(next_s_batch).detach().max(dim=1)
next_qvals = self.target_Q(next_s_batch).gather(1, next_a_batch.unsqueeze(1)).squeeze().detach()
loss = F.mse_loss(r_batch + self.discount * next_qvals * (1 - d_batch), qvals)
return loss
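This hunk is the core of Double DQN: the online Q network selects the next action and the target network evaluates it, which damps the overestimation that a single max produces. A small self-contained illustration with made-up numbers (the two tensors below are hypothetical stand-ins for the forward passes on next_s_batch):

import torch

# Hypothetical next-state Q values for a batch of 3 states and 2 actions.
q_next_online = torch.tensor([[1.0, 2.0], [0.5, 0.3], [4.0, 1.0]])  # online Q(s', a)
q_next_target = torch.tensor([[1.5, 0.8], [0.2, 0.9], [3.0, 2.5]])  # target Q(s', a)

# Plain DQN bootstrap: the target network both selects and evaluates.
dqn_bootstrap = q_next_target.max(dim=1).values  # tensor([1.5000, 0.9000, 3.0000])

# Double DQN bootstrap: the online network selects, the target network evaluates.
a_star = q_next_online.argmax(dim=1)  # tensor([1, 0, 0])
ddqn_bootstrap = q_next_target.gather(1, a_star.unsqueeze(1)).squeeze(1)  # tensor([0.8000, 0.2000, 3.0000])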

@@ -96,7 +98,7 @@ def set_seed(args):

def train(args, env, agent):
replay_buffer = ReplayBuffer(100_000)
optimizer = torch.optim.Adam(agent.model.parameters(), lr=args.lr)
optimizer = torch.optim.Adam(agent.Q.parameters(), lr=args.lr)
optimizer.zero_grad()

epsilon = 1
@@ -107,16 +109,16 @@ def train(args, env, agent):
log_ep_rewards = []
log_losses = [0]

agent.model.train()
agent.target_model.train()
agent.model.zero_grad()
agent.target_model.zero_grad()
agent.Q.train()
agent.target_Q.train()
agent.Q.zero_grad()
agent.target_Q.zero_grad()
state, _ = env.reset()
for i in range(args.max_steps):
if np.random.rand() < epsilon or i < args.warmup_steps:
action = env.action_space.sample()
else:
action = agent.get_action(torch.from_numpy(state))
action = agent.get_action(torch.from_numpy(state).to(args.device))
action = action.item()
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
@@ -140,8 +142,8 @@ def train(args, env, agent):
print(f"i={i}, reward={episode_reward:.0f}, length={episode_length}, max_reward={max_episode_reward}, loss={log_losses[-1]:.1e}, epsilon={epsilon:.3f}")

if episode_length < 180 and episode_reward > max_episode_reward:
save_path = os.path.join(args.output_dir, "model.bin")
torch.save(agent.model.state_dict(), save_path)
save_path = os.path.join(args.output_dir, "Q.bin")
torch.save(agent.Q.state_dict(), save_path)
max_episode_reward = episode_reward

episode_reward = 0
@@ -150,11 +152,11 @@ def train(args, env, agent):

if i > args.warmup_steps:
bs, ba, br, bd, bns = replay_buffer.sample(n=args.batch_size)
bs = torch.tensor(bs, dtype=torch.float32)
ba = torch.tensor(ba, dtype=torch.long)
br = torch.tensor(br, dtype=torch.float32)
bd = torch.tensor(bd, dtype=torch.float32)
bns = torch.tensor(bns, dtype=torch.float32)
bs = torch.tensor(bs, dtype=torch.float32).to(args.device)
ba = torch.tensor(ba, dtype=torch.long).to(args.device)
br = torch.tensor(br, dtype=torch.float32).to(args.device)
bd = torch.tensor(bd, dtype=torch.float32).to(args.device)
bns = torch.tensor(bns, dtype=torch.float32).to(args.device)

loss = agent.compute_loss(bs, ba, br, bd, bns)
loss.backward()
@@ -164,7 +166,7 @@ def train(args, env, agent):
log_losses.append(loss.item())

# Update the target network.
for target_param, param in zip(agent.target_model.parameters(), agent.model.parameters()):
for target_param, param in zip(agent.target_Q.parameters(), agent.Q.parameters()):
target_param.data.copy_(args.lr_target * param.data + (1 - args.lr_target) * target_param.data)
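This loop is the same soft (Polyak) update that 04_dqn.py performed through the now-removed soft_update helper, with args.lr_target playing the role of tau: target = tau * online + (1 - tau) * target after each update step. An equivalent helper form, shown only as a sketch (tau=0.01 follows the removed helper's default):

@torch.no_grad()
def soft_update(target_net, online_net, tau=0.01):
    # target <- tau * online + (1 - tau) * target, parameter by parameter.
    for tp, p in zip(target_net.parameters(), online_net.parameters()):
        tp.data.mul_(1.0 - tau).add_(tau * p.data)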

plt.plot(log_losses)
@@ -179,16 +181,16 @@ def eval(args, env, agent):

def eval(args, env, agent):
model_path = os.path.join(args.output_dir, "Q.bin")  # match the save path used in train()
agent.model.load_state_dict(torch.load(model_path))
agent.Q.load_state_dict(torch.load(model_path))

episode_length = 0
episode_reward = 0

agent.model.eval()
agent.Q.eval()
state, _ = env.reset()
for i in range(5000):
episode_length += 1
action = agent.get_action(torch.from_numpy(state)).item()
action = agent.get_action(torch.from_numpy(state).to(args.device)).item()
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
episode_reward += reward
@@ -225,7 +227,8 @@ def main():
env = gym.make(args.env)
set_seed(args)
agent = DoubleDQN(dim_obs=args.dim_obs, num_act=args.num_act, discount=args.discount)
agent.model.to(args.device)
agent.Q.to(args.device)
agent.target_Q.to(args.device)

if args.do_train:
train(args, env, agent)