-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsearch.py
More file actions
75 lines (65 loc) · 2.64 KB
/
search.py
File metadata and controls
75 lines (65 loc) · 2.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import argparse
import pymimir as mm
import pymimir_rgnn as rgnn
import torch
from pathlib import Path
from utils import create_device
class NeuralHeuristic(mm.Heuristic):
    """Search heuristic backed by a relational graph neural network.

    Each non-goal state is scored by running the wrapped model on a single
    ``(state, goal)`` pair and taking its ``'value'`` readout.
    """

    def __init__(self, model: rgnn.RelationalGraphNeuralNetwork):
        """Wrap ``model`` so it can be passed to pymimir search routines.

        Args:
            model: The relational GNN whose ``'value'`` readout is used as
                the heuristic estimate.
        """
        super().__init__()
        self._model = model

    def compute_value(self, state: mm.State, is_goal_state: bool) -> float:
        """Return the heuristic estimate for ``state``.

        Args:
            state: The state to evaluate; its goal is taken from
                ``state.get_problem().get_goal_condition()``.
            is_goal_state: When true, ``0.0`` is returned without running
                the model.

        Returns:
            The model's ``'value'`` readout for ``state``, as a Python float.
        """
        if is_goal_state:
            return 0.0
        # Inference only: no autograd graph is needed for heuristic lookups.
        with torch.no_grad():
            self._model.eval()
            problem = state.get_problem()
            goal = problem.get_goal_condition()
            value = self._model.forward([(state, goal)]).readout('value')[0]
        # NOTE(review): readout(...)[0] is presumably a torch tensor element
        # (possibly on an accelerator device). Convert it to a plain Python
        # float so the return value matches the annotated ``-> float`` and no
        # tensor is handed across the binding boundary.
        return float(value)
def _parse_arguments() -> argparse.Namespace:
    """Parse the command line of the search script.

    Returns:
        A namespace with ``domain``, ``problem`` and ``model`` attributes,
        each converted to a :class:`pathlib.Path`. All three are required;
        argparse exits with a usage error if any is missing.
    """
    cli = argparse.ArgumentParser(description='Settings for testing')
    # Every input is a mandatory path; declare them from one table, in order.
    path_options = (
        ('--domain', 'Path to the domain file'),
        ('--problem', 'Path to the problem file'),
        ('--model', 'Path to a pre-trained model'),
    )
    for flag, description in path_options:
        cli.add_argument(flag, required=True, type=Path, help=description)
    return cli.parse_args()
def _main(args: argparse.Namespace) -> None:
    """Load the planning task and GNN, run eager A*, and report the outcome.

    Args:
        args: Parsed command-line options providing the ``domain``,
            ``problem`` and ``model`` paths.
    """
    print(f'Torch: {torch.__version__}')
    domain = mm.Domain(args.domain)
    problem = mm.Problem(domain, args.problem)

    print(f'Loading model... ({args.model})')
    device = create_device()
    # Only the model itself is used; the second value from load() is ignored.
    model, _ = rgnn.RelationalGraphNeuralNetwork.load(domain, args.model, device)

    initial_state = problem.get_initial_state()
    heuristic = NeuralHeuristic(model)

    # Search statistics, updated in place by the callbacks below.
    counts = {'expanded': 0, 'generated': 0}

    def count_expansion(state):
        counts['expanded'] += 1

    def count_generation(state, action, cost, successor_state):
        counts['generated'] += 1

    def report_f_layer(f: float):
        expanded, generated = counts['expanded'], counts['generated']
        print(f'[f={f:.3f}] Expanded: {expanded}, Generated: {generated}')

    # Eager A*: successors are evaluated by the heuristic as they are generated.
    result = mm.astar_eager(
        problem,
        initial_state,
        heuristic,
        on_expand_state=count_expansion,
        on_generate_state=count_generation,
        on_finish_f_layer=report_f_layer,
    )

    expanded, generated = counts['expanded'], counts['generated']
    print(f'[Final] Expanded: {expanded}, Generated: {generated}')

    solved = result.status == "solved"
    if not solved:
        print('Failed to find a solution!')
        return

    plan = result.solution
    print(f'Found a solution of length {len(plan)}!')
    for step, action in enumerate(plan, start=1):
        print(f'{step:>4}: {str(action)}')
# Script entry point: parse the command line, then run the search.
if __name__ == '__main__':
    cli_args = _parse_arguments()
    _main(cli_args)