04_dqn.py has some bugs #4

@l1351868270

When I run the file with the command python .\04_dqn.py --do_train, I get the following error:

Traceback (most recent call last):
  File "E:\code\github\wangshusen\DeepRL-Chinese\04_dqn.py", line 242, in <module>
    main()
  File "E:\code\github\wangshusen\DeepRL-Chinese\04_dqn.py", line 235, in main
    train(args, env, agent)
  File "E:\code\github\wangshusen\DeepRL-Chinese\04_dqn.py", line 129, in train
    action = agent.get_action(torch.from_numpy(state))
  File "E:\code\github\wangshusen\DeepRL-Chinese\04_dqn.py", line 43, in get_action
    qvals = self.Q(state)
  File "E:\develop\anaconda3\envs\ray\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "E:\code\github\wangshusen\DeepRL-Chinese\04_dqn.py", line 29, in forward
    x = F.relu(self.fc1(state))
  File "E:\develop\anaconda3\envs\ray\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "E:\develop\anaconda3\envs\ray\lib\site-packages\torch\nn\modules\linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_mm)

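The model is on cuda:0 but the state tensor is created on the CPU, so the state has to be moved to the model's device before the forward pass. In train(), this line:

action = agent.get_action(torch.from_numpy(state))

should become:

action = agent.get_action(torch.from_numpy(state).to(args.device))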
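To illustrate the device fix for the single-state forward pass, here is a minimal sketch. The QNet class, its layer sizes, and the zero-filled observation are hypothetical stand-ins, not the actual code in 04_dqn.py; only the .to(device) pattern is the point.

```python
import numpy as np
import torch
import torch.nn as nn


# Hypothetical stand-in for the Q-network in 04_dqn.py; sizes are assumptions.
class QNet(nn.Module):
    def __init__(self, dim_state=4, num_action=2):
        super().__init__()
        self.fc1 = nn.Linear(dim_state, 64)
        self.fc2 = nn.Linear(64, num_action)

    def forward(self, state):
        return self.fc2(torch.relu(self.fc1(state)))


# Falls back to the CPU when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Q = QNet().to(device)

# Placeholder observation; in the real script this comes from the environment.
state = np.zeros(4, dtype=np.float32)

# torch.from_numpy always yields a CPU tensor, so move it to the model's device.
qvals = Q(torch.from_numpy(state).to(device))
action = int(torch.argmax(qvals).item())
```

Without the .to(device) call, the forward pass would raise the same RuntimeError whenever the model sits on a GPU.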

The batch tensors built for the update step have the same problem. DeepRL-Chinese/04_dqn.py, lines 158 to 162 in 55a9f5e:

bs = torch.tensor(bs, dtype=torch.float32)
ba = torch.tensor(ba, dtype=torch.long)
br = torch.tensor(br, dtype=torch.float32)
bd = torch.tensor(bd, dtype=torch.float32)
bns = torch.tensor(bns, dtype=torch.float32)

should become:

bs = torch.tensor(bs, dtype=torch.float32, device=args.device)
ba = torch.tensor(ba, dtype=torch.long, device=args.device)
br = torch.tensor(br, dtype=torch.float32, device=args.device)
bd = torch.tensor(bd, dtype=torch.float32, device=args.device)
bns = torch.tensor(bns, dtype=torch.float32, device=args.device)
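One possible way to keep the batch conversion in a single place is a small helper like the one below. This is only a sketch: to_device_batch is a hypothetical name that does not exist in the repository, and the sample batch values are made up for illustration.

```python
import numpy as np
import torch


def to_device_batch(bs, ba, br, bd, bns, device):
    """Hypothetical helper: convert a sampled batch to tensors on `device`."""
    # np.asarray first, so lists of arrays are converted in one step.
    bs = torch.as_tensor(np.asarray(bs), dtype=torch.float32, device=device)
    ba = torch.as_tensor(np.asarray(ba), dtype=torch.long, device=device)
    br = torch.as_tensor(np.asarray(br), dtype=torch.float32, device=device)
    bd = torch.as_tensor(np.asarray(bd), dtype=torch.float32, device=device)
    bns = torch.as_tensor(np.asarray(bns), dtype=torch.float32, device=device)
    return bs, ba, br, bd, bns


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Made-up batch of two transitions with 4-dimensional states.
bs, ba, br, bd, bns = to_device_batch(
    [[0.0] * 4, [1.0] * 4],   # states
    [0, 1],                   # actions
    [1.0, 0.5],               # rewards
    [0.0, 1.0],               # done flags
    [[1.0] * 4, [2.0] * 4],   # next states
    device,
)
```

All five tensors then live on the same device as the model, so the update step no longer mixes CPU and CUDA tensors.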
