forked from joshhartmann11/BattleSnakeArena
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvalidator.py
More file actions
68 lines (45 loc) · 2.22 KB
/
validator.py
File metadata and controls
68 lines (45 loc) · 2.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from game import GameParameters, Game
from snake import Snake
import torch
class Validator():
    """Plays validation games through a trainer and summarises the results."""

    def run_validation(self, validation_config, game_config, num_validation_games) -> dict:
        """Play ``num_validation_games`` validation games and summarise them.

        Args:
            validation_config: Mapping read with the keys ``"opponent"`` (a
                string used in the progress banner) and ``"trainer"`` here,
                and passed on to ``_run_validation_round`` for each game.
            game_config: Passed through unchanged to ``_run_validation_round``.
            num_validation_games: Number of games to play; must be positive.

        Returns:
            A dict with ``"mean_validation_reward"`` (the trainer's
            ``total_collected_reward`` divided by the number of games) and
            ``"win_rate"`` (the trainer's ``calculate_win_rate`` for its own
            controller's nickname).

        Raises:
            ValueError: If ``num_validation_games`` is not positive.
        """
        # Reject an empty run up front: the mean below divides by the game
        # count, which would otherwise fail only after the trainer was reset.
        if num_validation_games <= 0:
            raise ValueError(
                "num_validation_games must be a positive integer, got "
                + repr(num_validation_games)
            )

        # don't use gradients for validation runs for speed
        with torch.no_grad():
            opponent_name = validation_config["opponent"]
            print("\nVALIDATING AGAINST OPPONENT: " + opponent_name)

            validation_trainer = validation_config["trainer"]
            validation_trainer.reset()

            for i in range(num_validation_games):
                game_results = self._run_validation_round(validation_config, game_config)
                validation_trainer.print_training_result(game_results, i, num_validation_games)

            mean_validation_reward = (
                validation_trainer.total_collected_reward * 1.0 / num_validation_games
            )

            validation_results = {
                "mean_validation_reward": mean_validation_reward,
                "win_rate": validation_trainer.calculate_win_rate(
                    validation_trainer.controller.nickname
                ),
            }

            return validation_results

    def _run_validation_round(self, validation_config, game_config) -> dict:
        """Play a single validation game to completion.

        Args:
            validation_config: Mapping read with the keys ``"controllers"``
                (the controllers taking part), ``"trainer"`` and
                ``"controller_under_valuation"`` (the controller whose snake
                is handed to the trainer).
            game_config: Used to build the ``GameParameters`` for the game.

        Returns:
            Whatever ``game.get_game_results()`` returned when the game ended,
            or ``None`` if ``game.reset()`` already reported the game as done.
        """
        parameters = GameParameters(game_config)
        controllers = validation_config["controllers"]
        trainer = validation_config["trainer"]
        controller_under_validation = validation_config["controller_under_valuation"]

        snakes = []
        validating_snake = None

        # One snake per controller; remember the snake that belongs to the
        # controller being validated so the trainer can be pointed at it.
        for controller in controllers:
            snake = Snake(controller.nickname, None, controller)
            snakes.append(snake)

            if controller == controller_under_validation:
                validating_snake = snake

        game = Game(parameters, snakes)
        is_done = game.reset()

        # Stays None until the game reports that it has finished.
        game_results = None

        while not is_done:
            # Perform a game step
            is_done = game.step()

            if is_done:
                # get final game results
                game_results = game.get_game_results()

            # determine reward for the controller
            # NOTE(review): nesting reconstructed from a whitespace-stripped
            # copy; this is assumed to run once per step (with game_results
            # still None before the final step) -- confirm against upstream.
            trainer.determine_reward(validating_snake, game_results)

        trainer.finalize(game_results, validating_snake)

        return game_results