-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcheck_trained_data.py
More file actions
73 lines (49 loc) · 1.56 KB
/
check_trained_data.py
File metadata and controls
73 lines (49 loc) · 1.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import pandas as pd
import pickle
from agents import *
# Pickle file inspected when print_pkl() is called without arguments.
_DEFAULT_Q_TABLE_PATH = (
    '/root/EMPD/results/'
    'q_learning_history3_copykitten_episode_5000_max_len_1_epsilon_1_replace_0'
    '/q_table.pkl'
)


def print_pkl(file_path=_DEFAULT_Q_TABLE_PATH):
    """Load a pickled object from *file_path* and print its length and contents.

    Args:
        file_path: Path of the ``.pkl`` file to read. Defaults to the
            ``q_table.pkl`` result file this script was written to inspect.

    Raises:
        OSError: If *file_path* cannot be opened (e.g. it does not exist).
        TypeError: If the unpickled object does not support ``len()``.
    """
    # NOTE(security): pickle.load can execute arbitrary code stored in the
    # file, so only point this at result files from a trusted source.
    with open(file_path, 'rb') as file:
        data = pickle.load(file)

    # The object is printed as-is; no type check (DataFrame or otherwise) is
    # performed, it only needs to support len() and printing.
    print('q_learning_history3: ', len(data))
    print(' \n\n\n ', data)
# Checkpoint loaded by the DQN inspection helpers when no path is given.
_DEFAULT_DQN_CHECKPOINT = (
    '/root/EMPD/results/'
    '0609_dqn_multi_episode_3000_max_len_5_epsilon_1_replace_0/checkpoints.pt'
)


def _print_dqn_action_values(checkpoints, his_len, recent):
    """Load a DQN checkpoint and print its two outputs for a synthetic history.

    Shared implementation behind print_dqn_result() and linear_probing().
    """
    # Three-element encoding of a single round. The lists are rebuilt on every
    # call so nothing downstream can mutate state shared between calls.
    # NOTE(review): presumably 'no' marks an empty (not yet played) slot and
    # 'cc'/'cb'/'bc'/'bb' are cooperate/betray pairs -- confirm against the
    # encoding used in the agents module.
    encodings = {
        'no': [1, 0, 0],
        'cc': [0, 1, 1],
        'cb': [0, 1, 0],
        'bc': [0, 0, 1],
        'bb': [0, 0, 0],
    }
    back = [encodings[name] for name in recent]

    # Left-pad with 'no' rounds so the history holds his_len entries. When
    # `recent` already has his_len entries or more, the multiplier is <= 0 and
    # no padding is added.
    padding = [encodings['no']] * (his_len - len(back))

    # The history is stored under key 1, the same value passed as the first
    # argument of check_network below.
    history = {1: padding + back}

    dqn = DQN("DQN 0", 0)
    dqn.load_model(path=checkpoints)
    b, c = dqn.check_network(1, history, epsilon=0)
    print("Betray: ", b.item())
    print("Cooperate: ", c.item())


def print_dqn_result(checkpoints=_DEFAULT_DQN_CHECKPOINT, his_len=10,
                     recent=('cc', 'cc')):
    """Print the trained DQN's "Betray" and "Cooperate" outputs for a history.

    The history consists of 'no' rounds followed by the rounds named in
    *recent*, padded on the left to *his_len* entries.

    Args:
        checkpoints: Path of the checkpoint file passed to ``DQN.load_model``.
        his_len: Total number of rounds in the synthetic history.
        recent: Names of the most recent rounds, oldest first. Each name must
            be one of 'no', 'cc', 'cb', 'bc', 'bb'.

    Raises:
        KeyError: If *recent* contains a name that is not a known encoding.
    """
    _print_dqn_action_values(checkpoints, his_len, recent)


def linear_probing(checkpoints=_DEFAULT_DQN_CHECKPOINT, his_len=10,
                   recent=('cc', 'cc')):
    """Print the trained DQN's outputs for a synthetic history.

    NOTE(review): this currently performs exactly the same query as
    print_dqn_result(); no probing logic is implemented here.

    Args:
        checkpoints: Path of the checkpoint file passed to ``DQN.load_model``.
        his_len: Total number of rounds in the synthetic history.
        recent: Names of the most recent rounds, oldest first. Each name must
            be one of 'no', 'cc', 'cb', 'bc', 'bb'.

    Raises:
        KeyError: If *recent* contains a name that is not a known encoding.
    """
    _print_dqn_action_values(checkpoints, his_len, recent)
def _main():
    """Run the inspection selected for this script invocation."""
    # Only the pickled Q-table dump is active; print_dqn_result() is the
    # alternative check and is left switched off here.
    print_pkl()


if __name__ == "__main__":
    _main()