-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvisualize_agent.py
More file actions
105 lines (74 loc) · 2.95 KB
/
visualize_agent.py
File metadata and controls
105 lines (74 loc) · 2.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import argparse
import torch
import time
from gym import spaces
import utils
# Command-line interface. Every option has a default, so the script can be
# launched without arguments; help strings make `--help` self-explanatory.
parser = argparse.ArgumentParser(
    description=(
        "Run a saved agent for one episode, rendering the environment and "
        "printing per-step diagnostics."
    ),
)
parser.add_argument(
    "--device",
    default=None,
    type=str,
    help="torch device string (e.g. 'cpu' or 'cuda'); when omitted, CUDA is "
         "used if available, otherwise the CPU",
)
parser.add_argument(
    "--model_dir",
    default="full_agent",
    type=str,
    help="name of the agent directory inside the 'storage' folder next to "
         "this script (default: %(default)s)",
)
parser.add_argument(
    "--ltl_sampler",
    default="Dataset_e54test_no-shuffle",
    type=str,
    help="LTL sampler identifier forwarded to utils.make_env "
         "(default: %(default)s)",
)
parser.add_argument(
    "--seed",
    default=1,
    type=int,
    help="random seed forwarded to utils.set_seed and utils.make_env "
         "(default: %(default)s)",
)
parser.add_argument(
    "--formula_id",
    default=0,
    type=int,
    help="value assigned to env.sampler.sampled_tasks to choose the formula "
         "(default: %(default)s)",
)
args = parser.parse_args()
# Pick the torch device: an explicit --device value wins; otherwise prefer
# CUDA when it is available and fall back to the CPU.
if args.device:
    device_name = args.device
elif torch.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"
device = torch.device(device_name)

utils.set_seed(args.seed)

# Agent directory: <directory containing this script>/storage/<model_dir>.
script_path = os.path.abspath(__file__)
REPO_DIR = os.path.dirname(script_path)
storage_dir = os.path.join(REPO_DIR, "storage")
agent_dir = os.path.join(storage_dir, args.model_dir)

# Training configuration stored alongside the agent.
config = utils.load_config(agent_dir)
print(f"\nConfig:\n{config}")

# Training status stored alongside the agent.
status = utils.get_status(agent_dir, device)

# Environment, built from the stored config plus the CLI sampler and seed.
env = utils.make_env(
    config.env,
    config.progression_mode,
    args.ltl_sampler,
    args.seed,
    config.int_reward,
    config.noLTL,
    config.state_type,
    None,
    config.obs_size,
    config.max_num_steps,
)

obs_shape = env.observation_space['features'].shape
num_grounder_classes = 1 + len(env.propositions)

# Human-readable names for the four action indices.
action_to_str = dict(enumerate(("down", "right", "up", "left")))

# Select the formula requested on the command line.
env.sampler.sampled_tasks = args.formula_id
# Create the symbol grounder. The result may be None (the original code
# guards against it), in which case there are no weights to restore.
sym_grounder = utils.make_grounder(config.grounder_model, num_grounder_classes, obs_shape, True)
if sym_grounder is not None:
    # Restore the saved grounder weights and move the module to the
    # selected device.
    sym_grounder.load_state_dict(status["grounder_state"])
    sym_grounder.to(device)

# Hand the grounder (possibly None) to the wrapped environment.
env.env.sym_grounder = sym_grounder

# Build the agent from the environment, the stored config and agent_dir.
agent = utils.Agent(
    env, config.env, env.propositions, env.observation_space, env.action_space, agent_dir,
    config.ignoreLTL, config.progression_mode, config.gnn_model, config.recurrence, config.compositional,
    config.dumb_ac, device, False, 1, False,
)
# TEST: roll out one episode, rendering the environment and printing
# per-step diagnostics until env.step() reports that the episode is done.
obs = env.reset()
done = False
step = 0

try:
    while not done:
        step += 1

        env.show()
        time.sleep(0.5)  # half-second pause between rendered frames

        print("\n---")
        print(f"Step: {step}")
        print("Predicted Residual Task:")
        utils.pprint_ltl_formula(env.translate_formula(env.pred_ltl_goal))
        if env.real_ltl_goal != env.pred_ltl_goal:
            print("WRONG PREDICTED RESIDUAL FORMULA")

        action = agent.get_action(obs)
        if isinstance(env.action_space, spaces.Discrete):
            action = action.item()
            # Fall back to the raw index if it has no name in the table.
            action_label = action_to_str.get(action, action)
        else:
            # Non-discrete actions have no entry in action_to_str; show them raw.
            action_label = action
        print(f"\nAction: {action_label}")

        obs, reward, done, info = env.step(action)

        # Fetch each event set once so the printed symbols and the mismatch
        # check below refer to the same values.
        real_events = env.env.get_real_events()
        real_sym = env.env.translate_formula(real_events)
        print(f"Real Symbol: {real_sym}")

        pred_events = env.env.get_events()
        pred_sym = env.env.translate_formula(pred_events)
        print(f"Pred Symbol: {pred_sym}")

        if pred_events != real_events:
            print("WRONG PREDICTION")

        print(f"Reward: {reward}")

    # Render the final state after the last step.
    env.show()

    print("Done!")
    print("Closing...")
    time.sleep(2.0)
finally:
    # Always release the environment, even if the rollout raises or is
    # interrupted (e.g. Ctrl-C during one of the pauses).
    env.close()