-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmainRun.py
More file actions
312 lines (221 loc) · 9.59 KB
/
mainRun.py
File metadata and controls
312 lines (221 loc) · 9.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import argparse
import datetime
import os
# import files for the three different modules of the pipeline
import matplotlib
# import rlmethods
# import featureExtractor
'''
This is the main entry point that ties together the 3 modules needed to run
the pipeline:
1. Features
2. RL part
3. IRL part
It should also make passing arguments easy and convenient.
Possible arguments:
1. New run / Load?
2. With or without display
3. Environment arguments
4. Features to be used.
   4.1 If local window, size of the window.
5. RL method to be used
   5.1 Number of iterations till convergence.
6. IRL method to be used
'''
def read_arguments(argv=None):
    """
    Read arguments passed from the command line.

    Parameters
    ----------
    argv : list of str, optional
        Explicit argument list to parse. Defaults to ``sys.argv[1:]``
        when None, preserving the original behavior while making the
        function testable and reusable.

    Returns
    -------
    argparse.Namespace
        The parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description='Enter arguments to run the pipeline.')
    # arguments for external files that might be necessary to run the program
    parser.add_argument(
        '--cost_network', type=str,
        help='File storing the state dictionary of the cost network.'
    )
    parser.add_argument(
        '--policy_network', type=str,
        help='File storing the state dictionary of the policy network.'
    )
    parser.add_argument(
        '--state_dictionary', type=str,
        help='Environment on which to run the algo (obstacle/no obstacle)'
    )
    parser.add_argument(
        '--expert_trajectory_file', type=str,
        help='Path to file containing the expert trajectories.')
    # network hyper parameters
    parser.add_argument(
        '--cost_network_input', type=int, default=29,
        help='Input layer size of cost network. None if you have specified \
        cost network state dict.')
    parser.add_argument(
        '--cost_network_hidden', nargs='+', type=int, default=[256, 256],
        help='Hidden size of cost network. None if you have specified cost \
        network state dict.')
    parser.add_argument(
        '--cost_network_output', type=int, default=1,
        help='Output layer size of cost network. None if you have specified \
        cost network state dict.')
    parser.add_argument(
        '--policy_network_input', type=int, default=29,
        help='Input layer size of policy network. None if you have specified \
        policy network state dict.')
    parser.add_argument(
        '--policy_network_hidden', nargs='+', type=int, default=[256, 256],
        help='Hidden layer size of policy network. None if you have specified \
        policy network state dict.')
    parser.add_argument(
        '--policy_network_output', type=int, default=4,
        help='Output layer size of policy network. None if you have specified \
        policy network state dict.')
    # run hyperparameters
    parser.add_argument('--irl_iterations', type=int,
                        help='Number of times to iterate over the IRL part.')
    parser.add_argument(
        '--no_of_samples', type=int,
        help='Number of samples to create agent state visitation frequency.')
    parser.add_argument(
        '--rl_iterations', type=int,
        help='Number of iterations to be performed in the RL section.')
    # arguments for the I/O of the program
    # NOTE: boolean-like flags are deliberately strings ('True'/'False')
    # parsed later by parseBool(); kept for interface compatibility.
    parser.add_argument(
        '--display_board', type=str, default='False',
        help='If True, draw environment.')
    parser.add_argument(
        '--on_server', type=str, default='True',
        help='True if program is to run on a server (headless, no display).')
    parser.add_argument('--store_results', type=str, default='True')
    parser.add_argument(
        '--plot_interval', type=int, default=10,
        help='Iterations before loss and reward curve plots are stored.')
    parser.add_argument(
        '--savedict_policy_interval', type=int, default=100,
        help='Iterations after which the policy network will be stored.')
    parser.add_argument(
        '--savedict_cost_interval', type=int, default=1,
        help='Iterations after which the cost network will be stored.')
    # arguments for the broader pipeline
    parser.add_argument(
        '--rl_method', type=str,
        help='Enter the RL method to be used.')
    parser.add_argument(
        '--feature_space', type=str,
        help='Type of features to be used to get the state of the agent.')
    parser.add_argument('--irl_method', type=str,
                        help='Enter the IRL method to be used.')
    parser.add_argument(
        '--run_type', type=str, default='train',
        help='Enter if it is a train run or a test run.(train/test).')
    parser.add_argument(
        '--verbose', type=str, default='False',
        help='Set verbose to "True" to get a myriad of print statements crowd\
        your terminal. Necessary information should be provided with either\
        of the modes.')
    parser.add_argument(
        '--no_of_testRuns', type=int, default=0,
        help='If --run_type set to test, then this denotes the number of test \
        runs you want to conduct.')
    _args = parser.parse_args(argv)
    return _args
def dictToFilename(params):
    """
    Flatten a dict's values into an underscore-separated filename fragment.

    Parameters
    ----------
    params : dict
        Mapping whose values are concatenated in iteration order, each
        followed by a trailing underscore (e.g. {'a': 1, 'b': 2} -> '1_2_').

    Returns
    -------
    str
        The concatenated fragment ('' for an empty dict).
    """
    # The parameter was renamed from `dict`, which shadowed the builtin;
    # join() avoids the quadratic string-concatenation loop.
    return ''.join(str(value) + '_' for value in params.values())
def assertargs(args):
    """
    Validate the parsed command-line arguments.

    Currently a no-op placeholder that always returns 0; real
    inter-argument assertions are yet to be added.
    """
    # TODO: add validation of inter-dependent arguments
    return 0
def arrangeDirForStorage(irlMethod, rlMethod, costNNparams, policyNNparams):
    """
    Create timestamped directories for storing network snapshots.

    Builds ``saved-models-irl/<date>/<time>/{PolicyNetwork,CostNetwork}/``
    under the current working directory, creating the directories if they
    do not exist.

    Parameters
    ----------
    irlMethod : str
        Name of the IRL method (used in the filenames).
    rlMethod : str
        Name of the RL method (used in the filenames).
    costNNparams : dict
        Cost-network hyperparameters, flattened into the cost filename.
    policyNNparams : dict
        Policy-network hyperparameters, flattened into the policy filename.

    Returns
    -------
    dict
        Keys: 'basepath', 'costDir', 'policyDir', 'costFilename',
        'policyFilename'.
    """
    storageDict = {}
    # BUG FIX: take a single timestamp. The original called now() twice,
    # so date() and time() could straddle midnight and yield a path whose
    # date and time components disagree.
    now = datetime.datetime.now()
    curDay = str(now.date())
    curtime = str(now.time())
    basePath = 'saved-models-irl/'
    subPathPolicy = curDay + '/' + curtime + '/' + 'PolicyNetwork/'
    subPathCost = curDay + '/' + curtime + '/' + 'CostNetwork/'
    curDirPolicy = basePath + subPathPolicy
    curDirCost = basePath + subPathCost
    fileNamePolicy = irlMethod + '-' + rlMethod + '-' + dictToFilename(policyNNparams)
    fileNameCost = irlMethod + '-' + rlMethod + '-' + dictToFilename(costNNparams)
    if not os.path.exists(curDirPolicy):
        os.makedirs(curDirPolicy)
    if not os.path.exists(curDirCost):
        os.makedirs(curDirCost)
    storageDict['basepath'] = basePath + curDay + '/' + curtime + '/'
    storageDict['costDir'] = curDirCost
    storageDict['policyDir'] = curDirPolicy
    storageDict['costFilename'] = fileNameCost
    storageDict['policyFilename'] = fileNamePolicy
    return storageDict
def parseBool(stringarg):
    """
    Convert the literal strings 'True'/'False' to booleans.

    Parameters
    ----------
    stringarg : str
        Expected to be exactly 'True' or 'False'.

    Returns
    -------
    bool

    Raises
    ------
    ValueError
        If the string is neither 'True' nor 'False'. The original
        returned -1 here, which is truthy, so a typo like 'Flase'
        silently *enabled* the flag; failing loudly is safer.
    """
    if stringarg == 'True':
        return True
    if stringarg == 'False':
        return False
    raise ValueError("Expected 'True' or 'False', got %r" % (stringarg,))
'''
example running statement:
python mainRun.py --state_dictionary 'no obstacle' --display_board 'False' --on_server 'True' --expert_trajectory_file 'expertstateinfolong_50.npy' --irl_iterations 10 --no_of_samples 100 --rl_iterations 200 --rl_method 'Actor_Critic' --irl_method 'DeepMaxEnt' --run_type 'train' --cost_network '/home/abhisek/Study/Robotics/deepirl/saved-models-irl/2019-01-31/16:40:33.387966/CostNetwork/DeepMaxEnt-Actor_Critic-29_[256, 256]_1_iteration_0.h5' --policy_network '/home/abhisek/Study/Robotics/deepirl/saved-models-irl/2019-01-31/16:40:33.387966/PolicyNetwork/DeepMaxEnt-Actor_Critic-29_[256, 256]_4_iterEND_1.h5' --no_of_testRuns 0
'''
if __name__ == '__main__':
    args = read_arguments()
    # --- mandatory run configuration ---
    features = args.feature_space
    rlMethod = args.rl_method
    IRLMethod = args.irl_method
    demofile = args.expert_trajectory_file
    saveInfo = parseBool(args.store_results)
    display = parseBool(args.display_board)
    onServer = parseBool(args.on_server)
    verbose = parseBool(args.verbose)
    runType = args.run_type
    testRuns = args.no_of_testRuns
    # --- optional pre-trained network checkpoints ---
    costNetwork = args.cost_network
    policyNetwork = args.policy_network
    # network layer sizes, passed through to the network constructors
    costNNparams = {}
    policyNNparams = {}
    costNNparams['input'] = args.cost_network_input
    costNNparams['hidden'] = args.cost_network_hidden
    costNNparams['output'] = args.cost_network_output
    policyNNparams['input'] = args.policy_network_input
    policyNNparams['hidden'] = args.policy_network_hidden
    policyNNparams['output'] = args.policy_network_output
    irlIterations = args.irl_iterations
    sampling_no = args.no_of_samples
    rlIterations = args.rl_iterations
    plotIntervals = args.plot_interval
    rlModelStoreInterval = args.savedict_policy_interval
    irlModelStoreInterval = args.savedict_cost_interval
    # have to put this in the pipeline of the code, not touching this as of yet
    typeofEnvironment = args.state_dictionary
    print(saveInfo)
    if saveInfo:
        storageInfoDict = arrangeDirForStorage(
            IRLMethod, rlMethod, costNNparams, policyNNparams)
    else:
        storageInfoDict = None
    if onServer:
        # select a non-interactive backend before pyplot is imported anywhere
        matplotlib.use('Agg')
    import deepirl
    stateDict, _ = deepirl.getstateDict(args.state_dictionary)
    maxEntIrl = deepirl.DeepMaxEntIRL(demofile, rlMethod, costNNparams,
                                      costNetwork, policyNNparams,
                                      policyNetwork, irlIterations,
                                      sampling_no, rlIterations,
                                      store=saveInfo,
                                      storeInfo=storageInfoDict,
                                      render=display, onServer=onServer,
                                      resultPlotIntervals=plotIntervals,
                                      irlModelStoreInterval=irlModelStoreInterval,
                                      rlModelStoreInterval=rlModelStoreInterval,
                                      testIterations=testRuns, verbose=verbose)
    # BUG FIX: the original used two independent `if`s followed by `else`,
    # so a 'train' run also fell into the final else and printed
    # "I have not coded this option yet." after training finished.
    if runType == 'train':
        maxEntIrl.runDeepMaxEntIRL()
    elif runType == 'test':
        print('Starting test branch . . .')
        maxEntIrl.testMaxDeepIRL()
    else:
        print('I have not coded this option yet.')