-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathGameInNet2.py
More file actions
41 lines (31 loc) · 1.1 KB
/
GameInNet2.py
File metadata and controls
41 lines (31 loc) · 1.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import cv2
import torch
import numpy as np
from model.model import *
import argparse
# Command-line options. ``args`` is the parser object and ``opt`` the parsed
# namespace; ``opt.device`` is handed to ``.to(...)`` by the main loop
# (default 'cuda:0').
args = argparse.ArgumentParser()
args.add_argument(
    '--device',
    default='cuda:0',
    type=str,
)
opt = args.parse_args()
def on_EVENT_LBUTTONDOWN(event, x, y, flags, param):
    """OpenCV mouse callback: remember where the left button was last pressed.

    The click is stored as ``[row, col]`` (i.e. ``[y, x]``) in the module-level
    ``mouse_pos`` list, which the main loop reads every frame to feed the model
    and to draw the cursor marker. All other mouse events are ignored.

    Args:
        event: OpenCV mouse event code (``cv2.EVENT_*``).
        x: column of the cursor in image coordinates.
        y: row of the cursor in image coordinates.
        flags: unused; required by the ``cv2.setMouseCallback`` signature.
        param: unused; required by the ``cv2.setMouseCallback`` signature.
    """
    global mouse_pos
    if event == cv2.EVENT_LBUTTONDOWN:
        # Row/column order matches the numpy image indexing used when drawing.
        # No redraw here: the main loop refreshes the window every ~100 ms.
        mouse_pos = [y, x]
# Last left-click position in window pixels, stored as [row, col] (i.e. [y, x]).
# Rebound by on_EVENT_LBUTTONDOWN via ``global`` and read by the main loop below.
mouse_pos = [200, 200]


def _to_uint8(frame):
    """Convert a model frame with values nominally in [-1, 1] to uint8 pixels.

    Values are mapped linearly onto [0, 255] and clipped before the cast:
    converting out-of-range floats straight to ``np.uint8`` is undefined in
    NumPy and in practice wraps around, which shows up as speckle on screen.
    """
    return np.clip((frame / 2 + 0.5) * 255, 0, 255).astype(np.uint8)


def _marker_index(pos, shape):
    """Return the (row, col) frame pixel under window position *pos*.

    Window coordinates are divided by 4 (the same 256 -> 64 ratio used to scale
    the model input) and clamped to *shape*, so a click on the window edge can
    never index outside the frame.
    """
    row = min(max(pos[0] // 4, 0), shape[0] - 1)
    col = min(max(pos[1] // 4, 0), shape[1] - 1)
    return row, col


if __name__ == "__main__":
    device = torch.device(opt.device)

    # NOTE: torch.load unpickles arbitrary Python objects -- only load trusted
    # files. The checkpoint appears to be a whole pickled model (it is called
    # directly below), not a state_dict, so PyTorch >= 2.6, which defaults to
    # weights_only=True, must be told to allow full unpickling.
    try:
        model = torch.load("GameInNet.pkl", map_location=device, weights_only=False)
    except TypeError:  # PyTorch < 1.13 has no ``weights_only`` keyword.
        model = torch.load("GameInNet.pkl", map_location=device)
    # Honour --device for the model too (it used to be pinned to the default GPU).
    model = model.to(device)
    # NOTE(review): the model is not switched to eval(); if it contains dropout or
    # batch-norm layers, confirm which mode it was saved in.

    cv2.namedWindow("GAME")
    cv2.setMouseCallback("GAME", on_EVENT_LBUTTONDOWN)

    # Initial recurrent state: one 6-value vector repeated 30 times -> (1, 30, 6).
    hid_state = torch.tensor(
        [[[8, 8, 17, 17, 17, 17]]], dtype=torch.float32, device=device
    ).repeat(1, 30, 1)

    # The hidden state is fed back into the model every frame; with autograd on,
    # each step would stay attached to all previous ones and the graph (and
    # memory use) would grow without bound.
    with torch.no_grad():
        while True:
            # Window pixels (0..256) rescaled to model coordinates (0..64).
            mouse_tensor = (
                torch.tensor([mouse_pos], dtype=torch.float32, device=device) / 256 * 64
            )
            print(hid_state, mouse_tensor)  # per-frame debug trace of the inputs
            image, hid_state = model(hid_state, mouse_tensor)

            frame = _to_uint8(image.detach().cpu().squeeze().numpy())
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
            # Mark the cursor with a red pixel (channel 2 is red in BGR order).
            row, col = _marker_index(mouse_pos, frame.shape)
            frame[row, col, 2] = 255
            frame = cv2.resize(frame, (256, 256))
            cv2.imshow("GAME", frame)

            # Wait ~100 ms for events; Esc or 'q' ends the session.
            if (cv2.waitKey(100) & 0xFF) in (27, ord("q")):
                break
    cv2.destroyAllWindows()