-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathestimator.patch
More file actions
24 lines (24 loc) · 820 Bytes
/
estimator.patch
File metadata and controls
24 lines (24 loc) · 820 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
47c47
<
---
> import tensorflow as tf
228c228
< def _create_keras_model_fn(keras_model, custom_objects=None):
---
> def _create_keras_model_fn(keras_model, params=None, custom_objects=None):
239c239
< def model_fn(features, labels, mode):
---
> def model_fn(features, labels, mode, params=None):
448c448
< if keras_model._is_graph_network:
---
> if False:
462,464c462,464
< estimator = estimator_lib.Estimator(keras_model_fn,
< config=config,
< warm_start_from=warm_start_path)
---
> estimator = tf.contrib.tpu.TPUEstimator(keras_model_fn, use_tpu=False, train_batch_size=4, eval_batch_size=4,
> config=config,
> warm_start_from=warm_start_path)