-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreinforce.cc
More file actions
44 lines (36 loc) · 823 Bytes
/
reinforce.cc
File metadata and controls
44 lines (36 loc) · 823 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include "utils.h"
#include "agent.h"
#include "puck_world.h"
using std::string;
using std::vector;
using std::shared_ptr;
int main(int argc, char *argv[])
{
(void)argc;
(void)argv;
// srand(time(NULL));
srand(0);
math = shared_ptr<Math>(new MathCpu);
math->Init();
shared_ptr<PuckWorld> env(new PuckWorld);
shared_ptr<DQNAgent> agent(
new DQNAgent(env->GetNumStates(), env->GetMaxNumActions()));
float reward = 0;
for (int step = 0; step < 1000000; ++step)
{
shared_ptr<Mat> state = env->GetState();
int action = agent->Act(state);
int r = env->SampleNextState(action);
if (step > 0)
{
agent->Learn(r);
}
reward += r;
if (step % 1000 == 0 && step != 0)
{
printf("reward: %.3f\n", reward / 1000);
reward = 0;
}
}
return 0;
}