-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathchess_graph.py
More file actions
178 lines (167 loc) · 7.13 KB
/
chess_graph.py
File metadata and controls
178 lines (167 loc) · 7.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
from typing import cast, List, Tuple
import jax
import jax.numpy as jnp
import numpy as np
import jraph as jraph
import pgx
import pgx.chess as pgc
import pgx.gardner_chess as pgg
from rich.pretty import pprint
def _state_nodes(observation: jnp.ndarray) -> Tuple[int, jnp.ndarray]:
assert(observation.ndim == 3)
n_row, n_col = observation.shape[:2]
# cell order is the same as FEN order
# features = observation.reshape((n_row * n_col, -1))
# cell order is the same as pgc.Action._from_label
features = jnp.rot90(observation, k=-1).reshape((n_row * n_col, -1))
if features.shape[-1] == 115: # Add dummy features for 8x8 compatibility
zeros = jnp.zeros(features.shape[:-1] + (119,))
zeros = zeros.at[:,:114].set(features[:,:114])
features = zeros.at[:,118].set(features[:,114])
assert(features.shape[-1] == 119)
return n_row * n_col, features
def _state_edges( # TODO: add self-edge
legal_action_mask: jnp.ndarray,
offset_id: int=1
) -> Tuple[int, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
assert(legal_action_mask.ndim == 1)
gardner = legal_action_mask.shape[-1] != 4672
size = 5 if gardner else 8
_pgc = pgg if gardner else pgc
# TODO: Missing opponent's moves (pawn moves, castling, ?)
all_moves = jax.vmap(_pgc.Action._from_label)(
jnp.arange(legal_action_mask.shape[-1])
)
real_moves_id = jnp.where(
all_moves.to != -1,
size=455 if gardner else 1858,
fill_value=-1
)[0] # The fill_value should never be used
action_edge_id = jnp.full(legal_action_mask.shape, -1) \
.at[real_moves_id] \
.set(jnp.arange(real_moves_id.shape[0]))
all_moves = jax.tree_map(lambda x: x[real_moves_id], all_moves)
edge_from = all_moves.from_ + offset_id
edge_to = all_moves.to + offset_id
n_features = 1+2+4+2*6 # legal, grid offsets, promotions, piece type
delta_file = (all_moves.to // size) - (all_moves.from_ // size)
delta_rank = (all_moves.to % size) - (all_moves.from_ % size)
edge_features = jnp.where(
(all_moves.to == -1)[:, None].repeat(n_features, axis=-1),
0,
jnp.stack([
# legal
legal_action_mask[real_moves_id],
# legal_action_mask,
# grid offsets
delta_file,
delta_rank,
# promotions
# doesn't distinguish between pawn promoting to queen and other
# pieces moving to the 8th rank in a way a pawn could
(
(all_moves.from_ % size == size-2)
& (all_moves.to % size == size-1)
& (jnp.abs(all_moves.to // size - all_moves.from_ // size) <= 1)
& (all_moves.underpromotion == 0)
), # queen
all_moves.underpromotion == 0, # rook
all_moves.underpromotion == 1, # bishop
all_moves.underpromotion == 2 # knight
] + sum([[ # TODO: add castling
( # pawn
(jnp.abs(delta_file) <= 1)
& (
(delta_rank == (1 if player else -1))
| (
(delta_rank == (2 if player else -2))
& (all_moves.from_ % size == (1 if player else size-2))
) # TODO: remove this case for gardner chess (no torpedo)
)
),
( # knight
(jnp.abs(delta_file) == 1) & (jnp.abs(delta_rank) == 2)
| (jnp.abs(delta_file) == 2) & (jnp.abs(delta_rank) == 1)
),
( # bishop
(jnp.abs(delta_file) == jnp.abs(delta_rank))
& (all_moves.underpromotion < 0)
),
( # rook
((jnp.abs(delta_file) == 0) | (jnp.abs(delta_rank) == 0))
& (all_moves.underpromotion < 0)
),
( # queen (might be useless as queen == bishop | rook)
(
(jnp.abs(delta_file) == jnp.abs(delta_rank))
| ((jnp.abs(delta_file) == 0) | (jnp.abs(delta_rank) == 0))
)
& (all_moves.underpromotion < 0)
),
( # king
((jnp.abs(delta_file) <= 1) & (jnp.abs(delta_rank) <= 1))
& (all_moves.underpromotion < 0)
)
] for player in range(2)], []), axis=-1)
)
return edge_from.shape[0], edge_from, edge_to, edge_features, action_edge_id
def state_to_graph(
observation: jnp.ndarray,
legal_action_mask: jnp.ndarray
) -> jraph.GraphsTuple:
n_nodes, node_features = jax.vmap(_state_nodes)(observation)
n_nodes = cast(jnp.ndarray, n_nodes)
node_features = node_features.reshape((-1, node_features.shape[-1]))
offsets = jnp.concatenate([
jnp.zeros(1, dtype=jnp.int32),
n_nodes[:-1]
]).cumsum()
n_edges, moves_from, moves_to, edges_features, action_edge_id = (
jax.vmap(_state_edges)(
legal_action_mask,
offset_id=cast(int, offsets)
)
)
moves_from = moves_from.reshape((-1,)).astype(jnp.int32)
moves_to = moves_to.reshape((-1,)).astype(jnp.int32)
edge_offsets = jnp.arange(action_edge_id.shape[0]) * edges_features.shape[1]
edges_features = edges_features.reshape((-1,) + edges_features.shape[2:])
edge_offsets = edge_offsets.repeat(action_edge_id.shape[1])
action_edge_id = action_edge_id.reshape((-1,))
action_edge_id = jnp.where(
action_edge_id == -1,
-1,
action_edge_id + edge_offsets
)
return jraph.GraphsTuple(
n_node=n_nodes,
nodes=node_features,
n_edge=cast(jnp.ndarray, n_edges),
edges=edges_features,
senders=moves_from,
receivers=moves_to,
globals=action_edge_id
)
def main():
env = pgx.make("gardner_chess")
state = env.init(jax.random.PRNGKey(0))
state = jax.tree_map(lambda x: x[None], state)
x = jax.jit(state_to_graph)(state.observation, state.legal_action_mask)
# pprint(x)
# pprint(x.n_node)
# pprint(x.nodes.shape)
# pprint(np.array(list(" PNBRQKpnbrqk"))[jnp.rot90((x.nodes[1:,:12] * jnp.arange(1, 13)).sum(axis=-1).reshape((-1,8,8)), axes=(1,2)).astype(jnp.int32)])
pprint(x.nodes.shape)
print(' l Δx Δy pq pr pb pn p n b r q k P N B R Q K ')
print(x.edges[jnp.where(state.legal_action_mask.reshape((-1,)))])
pprint(jnp.where(state.legal_action_mask.reshape((-1,))))
# states = jax.vmap(env.init)(jax.random.split(jax.random.PRNGKey(0), 2))
# x = jax.jit(state_to_graph)(states.observation, states.legal_action_mask)
# pprint((x.n_node, x.n_edge, x.nodes.shape, x.edges.shape))
# pprint(np.array(list(" PNBRQKpnbrqk"))[jnp.rot90((x.nodes[1:,:12] * jnp.arange(1, 13)).sum(axis=-1).reshape((-1,8,8)), axes=(1,2)).astype(jnp.int32)])
# print(' l Δx Δy pq pr pb pn p n b r q k P N B R Q K ')
# print(x.edges[jnp.where(state.legal_action_mask.reshape((-1)))])
# print(x.receivers[jnp.where(states.legal_action_mask.reshape((-1,)))])
# pprint((x.receivers.shape, x.receivers.min(), x.receivers.max()))
if __name__ == "__main__":
main()