-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreward_func.py
More file actions
76 lines (67 loc) · 2.56 KB
/
reward_func.py
File metadata and controls
76 lines (67 loc) · 2.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import tensorflow.keras.backend as K
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, BatchNormalization, Add, Lambda, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam, SGD
def evaluator_net(config):
    """
    Build and compile the evaluator (reward-prediction) network with Keras.

    The network concatenates three inputs — environment observation,
    detector observation, and action vector — feeds them through five
    identical residual blocks, then a final dense layer, and regresses a
    single scalar reward with an MSE loss.

    :param config: configuration object providing:
        ``units`` — width of every hidden Dense layer;
        ``obs_env`` — size of the environment-observation input;
        ``obs_det`` — size of the detector-observation input;
        ``action_size`` — size of the action input;
        ``learning_rate``, ``lr_decay`` — Adam hyper-parameters.
    :return: a compiled Keras ``Model`` mapping
        ``[env_obs, det_obs, action]`` to a predicted reward of shape
        ``(batch, 1)``.
    """
    def _residual_block(x, units):
        # One block: Dense -> BN -> Dense -> BN -> Dense -> BN,
        # skip-add with the first Dense output, then a final BN.
        skip = Dense(units, activation='relu')(x)
        y = BatchNormalization()(skip)
        y = Dense(units, activation='relu')(y)
        y = BatchNormalization()(y)
        y = Dense(units, activation='relu')(y)
        y = BatchNormalization()(y)
        y = Add()([y, skip])
        return BatchNormalization()(y)

    units = config.units
    inp_env = Input((config.obs_env,))
    inp_det = Input((config.obs_det,))
    # Fix: use the configured action size. The original read
    # config.action_size into an unused variable and hard-coded
    # Input((3,)), which broke any config with action_size != 3.
    inp_act = Input((config.action_size,))

    x = concatenate([inp_env, inp_det, inp_act])
    # Five stacked residual blocks (the original hand-unrolled these).
    for _ in range(5):
        x = _residual_block(x, units)
    x = Dense(units, activation='relu')(x)
    x = BatchNormalization()(x)
    out_reward = Dense(1, activation="linear")(x)

    m = Model([inp_env, inp_det, inp_act], [out_reward])
    # NOTE(review): Adam's `decay` argument was removed in TF >= 2.11
    # (replaced by learning-rate schedules) — confirm the pinned TF version
    # still accepts it before upgrading.
    m.compile(optimizer=Adam(config.learning_rate, decay=config.lr_decay),
              loss="MSE", metrics=["MAE"])
    return m