# alphaZeroParallel.py
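"""Parallelised AlphaZero training loop.

Runs a batch of self-play games side by side through a batched MCTS
(MCTSParallel), then trains the policy/value network on the collected
(encoded state, policy target, outcome) triples.
"""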
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import trange

from mctsParallel import MCTSParallel

class AlphaZeroParallel:
    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTSParallel(model, game, args)

    def selfPlay(self):
        """Play args['num_parallel_games'] games to completion and return
        their (encoded state, policy target, outcome) training triples."""
        return_memory = []
        player = 1
        spGames = [SelfPlayGame(self.game) for _ in range(self.args['num_parallel_games'])]

        while len(spGames) > 0:
            # Batch all live games into one array so MCTS can evaluate them together.
            states = np.stack([g.state for g in spGames])
            neutral_states = self.game.change_perspective(states, player)
            self.mcts.search(neutral_states, spGames)

            # Iterate in reverse so finished games can be deleted by index safely.
            for i in reversed(range(len(spGames))):
                g = spGames[i]

                # Root visit counts become the policy target for this position.
                action_probs = np.zeros(self.game.action_size)
                for child in g.root.children:
                    action_probs[child.action_taken] = child.visit_count
                action_probs /= np.sum(action_probs)

                g.memory.append((g.root.state, action_probs, player))

                # Sample the move from the temperature-adjusted distribution.
                temperature_action_probs = action_probs ** (1 / self.args['temperature'])
                temperature_action_probs /= np.sum(temperature_action_probs)
                action = np.random.choice(self.game.action_size, p=temperature_action_probs)

                g.state = self.game.get_next_state(g.state, action, player)

                value, is_terminal = self.game.get_value_and_terminated(g.state, action)
                if is_terminal:
                    # Label every stored position with the final outcome, seen
                    # from the perspective of the player to move at that position.
                    for hist_neutral_state, hist_action_probs, hist_player in g.memory:
                        hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                        return_memory.append((
                            self.game.get_encoded_state(hist_neutral_state),
                            hist_action_probs,
                            hist_outcome
                        ))
                    del spGames[i]

            player = self.game.get_opponent(player)

        return return_memory

    def train(self, memory):
        """Run one pass of gradient descent over the shuffled self-play memory."""
        random.shuffle(memory)
        for batchIdx in range(0, len(memory), self.args['batch_size']):
            sample = memory[batchIdx:batchIdx + self.args['batch_size']]
            state, policy_targets, value_targets = zip(*sample)

            state = np.array(state)
            policy_targets = np.array(policy_targets)
            value_targets = np.array(value_targets).reshape(-1, 1)

            state = torch.tensor(state, dtype=torch.float32, device=self.model.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

            out_policy, out_value = self.model(state)

            # Cross-entropy against the soft MCTS policy, MSE against the game outcome.
            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self):
        """Alternate self-play and training for args['num_iterations'] rounds,
        checkpointing the model and optimizer after each round."""
        for iteration in range(self.args['num_iterations']):
            memory = []

            self.model.eval()
            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations'] // self.args['num_parallel_games']):
                memory += self.selfPlay()

            self.model.train()
            for epoch in trange(self.args['num_epochs']):
                self.train(memory)

            # Ensure the checkpoint directory exists before saving.
            os.makedirs(f"Models/{self.game}", exist_ok=True)
            torch.save(self.model.state_dict(), f"Models/{self.game}/model_{iteration}.pt")
            torch.save(self.optimizer.state_dict(), f"Models/{self.game}/optimizer_{iteration}.pt")

class SelfPlayGame:
    """Per-game container for one of the parallel self-play games."""

    def __init__(self, game):
        self.state = game.get_initial_state()
        self.memory = []  # (neutral_state, action_probs, player) per move played
        self.root = None  # current root node, set by the parallel MCTS search
        self.node = None  # node this game is expanding within a search, if any
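
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The game and model
# classes, module names, and hyperparameter values below are illustrative
# assumptions, not confirmed parts of this repository; 'C', 'num_searches',
# and the Dirichlet-noise settings are presumably read inside MCTSParallel.
# ---------------------------------------------------------------------------
#
# from tictactoe import TicTacToe   # hypothetical game implementation
# from model import ResNet          # hypothetical network with a .device attribute
#
# game = TicTacToe()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = ResNet(game, num_resBlocks=4, num_hidden=64, device=device)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
#
# args = {
#     'C': 2,                          # MCTS exploration constant
#     'num_searches': 60,              # MCTS simulations per move
#     'num_iterations': 3,             # outer loop of learn()
#     'num_selfPlay_iterations': 500,  # self-play games per iteration
#     'num_parallel_games': 100,       # games batched per selfPlay() call
#     'num_epochs': 4,                 # training epochs per iteration
#     'batch_size': 64,
#     'temperature': 1.25,             # >1 flattens the move-sampling distribution
#     'dirichlet_epsilon': 0.25,
#     'dirichlet_alpha': 0.3,
# }
#
# AlphaZeroParallel(model, optimizer, game, args).learn()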