This repository was archived by the owner on Nov 28, 2025. It is now read-only.

Commit 791a42f

aaudiber and jhseu authored
Add example of running the tf.data service in GKE. (#162)
Co-authored-by: Jonathan Hseu <vomjom@vomjom.net>
1 parent b28db03 commit 791a42f

5 files changed

Lines changed: 281 additions & 0 deletions

data_service/Dockerfile.tf_std_data_server

Lines changed: 4 additions & 0 deletions
FROM tensorflow/tensorflow:nightly

COPY tf_std_data_server.py /
ENTRYPOINT ["python", "-u", "/tf_std_data_server.py"]

data_service/README.md

Lines changed: 143 additions & 0 deletions
# Distributed input processing with tf.data service

This directory provides an example of running the tf.data service to
horizontally scale tf.data input processing. We use GKE
(Google Kubernetes Engine) to manage the tf.data servers.

This directory contains the following files:

- `Dockerfile.tf_std_data_server`: A Dockerfile to build a tf.data server image.
- `data_service.yaml.jinja`: A Jinja-templated Kubernetes definition for running
  tf.data service servers.
- `data_service_interfaces.yaml.jinja`: A Jinja-templated Kubernetes definition
  for creating load balancers which expose the tf.data service endpoints
  outside the GKE cluster (but within the same VPC network). This is needed
  for TPUs to be able to connect to servers running in GKE.
- `tf_std_data_server.py`: A basic tf.data server implementation.
## Run the tf.data service in GKE

### Start a GKE cluster

If you don't already have a [GKE](https://cloud.google.com/kubernetes-engine)
cluster, create one:

- Replace `${CLUSTER_NAME}` with a name of your choice.
- Replace `${NUM_NODES}` with the number of tf.data service machines to run, e.g. `8`.
- Replace `${MACHINE_TYPE}` with the machine type to use, e.g. `e2-standard-4`.

```
gcloud container clusters create ${CLUSTER_NAME} --zone europe-west4-a \
  --scopes=cloud-platform --enable-ip-alias --num-nodes=${NUM_NODES} \
  --machine-type=${MACHINE_TYPE}
```

`--enable-ip-alias` is needed to be able to connect to the cluster from a TPU.
### Create service endpoints

Edit the variable at the start of `data_service_interfaces.yaml.jinja` to set
the number of workers: `{%- set workers = 8 -%}`

Create the data service endpoints so that the data service can be accessed from
outside GKE. This requires `jinja2`; install it if you don't have it already:
`pip3 install jinja2`.

```
python3 ../render_template.py data_service_interfaces.yaml.jinja | kubectl apply -f -
```
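`render_template.py` lives one directory up and is not part of this commit. As a rough sketch of what it presumably does (an assumption; the actual script may differ), rendering a Jinja template file to stdout with `jinja2` looks like:

```
import sys

from jinja2 import Template


def render(path):
    """Reads a Jinja template file and returns the rendered text."""
    with open(path) as f:
        return Template(f.read()).render()


if __name__ == "__main__":
    # Usage: python3 render_template.py data_service_interfaces.yaml.jinja
    sys.stdout.write(render(sys.argv[1]))
```

Piping the rendered text into `kubectl apply -f -` then creates the resources, as in the command above.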
### Create tf.data server image

Replace `${PROJECT_ID}` with your Google Cloud project ID.

```
docker build --no-cache -t gcr.io/${PROJECT_ID}/tf_std_data_server:latest \
  -f Dockerfile.tf_std_data_server .
docker push gcr.io/${PROJECT_ID}/tf_std_data_server:latest
```
### Start tf.data servers

Edit `data_service.yaml.jinja`, setting the image variable at the top of the
file to the image created in the previous step, e.g.
`"gcr.io/${PROJECT_ID}/tf_std_data_server:latest"`.

Wait for GKE to assign endpoints for all services created in the "Create service
endpoints" step. This may take a few minutes. The command below queries all
worker endpoints:

```
kubectl get services -o=jsonpath='{"\n"}{range .items[*]}"{.metadata.name}": "{.status.loadBalancer.ingress[*].ip}",{"\n"}{end}{"\n"}' | grep data-service-worker
```
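Copying addresses by hand is easy to get wrong. As an illustrative helper (not part of this commit; the function name is made up), the worker-to-IP pairs can also be extracted from `kubectl get services -o json` output:

```
import json


def build_ip_mapping(kubectl_json):
    """Maps each data-service worker Service name to its assigned internal IP.

    `kubectl_json` is the output of `kubectl get services -o json`. Workers
    whose load balancer has no address yet map to an empty string.
    """
    mapping = {}
    for item in json.loads(kubectl_json)["items"]:
        name = item["metadata"]["name"]
        if not name.startswith("data-service-worker"):
            continue
        ingress = item.get("status", {}).get("loadBalancer", {}).get("ingress", [])
        mapping[name] = ingress[0]["ip"] if ingress else ""
    return mapping
```

Once every value in the mapping is non-empty, it can be formatted and pasted into `data_service.yaml.jinja` as described below.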
Once the command shows non-empty addresses for all workers, copy its output
into the `ip_mapping` variable at the start of `data_service.yaml.jinja`:

```
{% set ip_mapping = {
  "data-service-worker-0": "10.164.0.40",
  "data-service-worker-1": "10.164.0.41",
  ...
} %}
```

Now launch the tf.data servers:

```
python3 ../render_template.py data_service.yaml.jinja | kubectl apply -f -
```

The service is now ready to use. To find the service address, run

```
kubectl get services data-service-master
```

and examine the `EXTERNAL-IP` and `PORT(S)` columns. To access the cluster,
use the string `'grpc://<EXTERNAL-IP>:<PORT>'`.
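On the training side (not shown in this commit), a pipeline typically consumes the service via `tf.data.experimental.service.distribute`. A minimal sketch, assuming the TF 2.x experimental API and a placeholder service address (the exact API surface may differ across TF versions):

```
import tensorflow as tf


def make_distributed_dataset(service_address):
  """Builds an input pipeline whose elements are produced by the tf.data
  service workers instead of the local process.

  service_address: the `'grpc://<EXTERNAL-IP>:<PORT>'` string found above
  (a placeholder here, not a real endpoint).
  """
  dataset = tf.data.Dataset.range(1000)
  dataset = dataset.map(
      lambda x: x * 2, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Everything defined above this point is executed by the tf.data servers.
  return dataset.apply(
      tf.data.experimental.service.distribute(
          processing_mode="parallel_epochs", service=service_address))
```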
## Run ResNet using the tf.data service for input

The `classifier_trainer.py` script in the [TensorFlow Model
Garden](https://github.com/tensorflow/models) supports using the tf.data
service to get input data.

To run the script, do the following:

```
git clone https://github.com/tensorflow/models.git
cd models/official/vision/image_classification
```

Edit either `configs/examples/resnet/imagenet/gpu.yaml` or
`configs/examples/resnet/imagenet/tpu.yaml`, depending on whether you want to
run on GPU or TPU. Under the `train_dataset` and `validation_dataset` sections,
update `builder` from `'tfds'` to `'records'`. Then under the `train_dataset`
section, add `tf_data_service: 'grpc://<EXTERNAL-IP>:<PORT>'`.

Finally, run the ResNet model:

```
export PYTHONPATH=/path/to/models
python3 classifier_trainer.py \
  --mode=train_and_eval --model_type=resnet --dataset=imagenet --tpu=$TPU_NAME \
  --model_dir=$MODEL_DIR --data_dir=gs://cloud-tpu-test-datasets/fake_imagenet \
  --config_file=path/to/config
```
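After those edits, the relevant part of the config would look along these lines (a sketch: only the `builder` and `tf_data_service` entries come from the steps above; any other fields in the real config files are unchanged and omitted here):

```
train_dataset:
  builder: 'records'
  tf_data_service: 'grpc://<EXTERNAL-IP>:<PORT>'
validation_dataset:
  builder: 'records'
```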
## Restarting tf.data servers

tf.data servers are meant to live for the duration of a single training job.
When starting a new job, you can use the following command to stop the tf.data
servers:

```
kubectl get rs --no-headers=true | grep "data-service-" | xargs kubectl delete rs
```

Then to start the servers again, run

```
python3 ../render_template.py data_service.yaml.jinja | kubectl apply -f -
```
data_service/data_service.yaml.jinja

Lines changed: 50 additions & 0 deletions
{%- set image = "gcr.io/<project_id>/tf_std_data_server:latest" -%}
{%- set port = 5050 -%}
{% set ip_mapping = {
} %}

kind: ReplicaSet
apiVersion: extensions/v1beta1
metadata:
  name: data-service-master
spec:
  replicas: 1
  template:
    metadata:
      labels:
        name: data-service-master
    spec:
      containers:
      - name: tensorflow
        image: {{ image }}
        ports:
        - containerPort: {{ port }}
        args:
        - "--port={{ port }}"
        - "--is_master=true"
---

{% for worker_name, worker_ip in ip_mapping.items() %}
kind: ReplicaSet
apiVersion: extensions/v1beta1
metadata:
  name: {{ worker_name }}
spec:
  replicas: 1
  template:
    metadata:
      labels:
        name: {{ worker_name }}
    spec:
      containers:
      - name: tensorflow
        image: {{ image }}
        ports:
        - containerPort: {{ port }}
        args:
        - "--port={{ port }}"
        - "--is_master=false"
        - "--master_address=data-service-master:{{ port }}"
        - "--worker_address={{ worker_ip }}:{{ port }}"
---
{% endfor %}
data_service/data_service_interfaces.yaml.jinja

Lines changed: 35 additions & 0 deletions
{%- set workers = 8 -%}
{%- set port = 5050 -%}

kind: Service
apiVersion: v1
metadata:
  name: data-service-master
  annotations:
    cloud.google.com/load-balancer-type: "Internal"
spec:
  type: LoadBalancer
  selector:
    name: data-service-master
  ports:
  - port: {{ port }}
    targetPort: {{ port }}
    protocol: TCP
---
{% for i in range(workers) %}
kind: Service
apiVersion: v1
metadata:
  name: data-service-worker-{{ i }}
  annotations:
    cloud.google.com/load-balancer-type: "Internal"
spec:
  type: LoadBalancer
  selector:
    name: data-service-worker-{{ i }}
  ports:
  - port: {{ port }}
    targetPort: {{ port }}
    protocol: TCP
---
{% endfor %}

data_service/tf_std_data_server.py

Lines changed: 49 additions & 0 deletions
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Run a tf.data service server."""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import tensorflow as tf
21+
22+
flags = tf.compat.v1.app.flags
23+
24+
flags.DEFINE_integer("port", 0, "Port to listen on")
25+
flags.DEFINE_bool("is_master", False, "Whether to start a master (as opposed to a worker server")
26+
flags.DEFINE_string("master_address", "", "The address of the master server. This is only needed when starting a worker server.")
27+
flags.DEFINE_string("worker_address", "", "The address of the worker server. This is only needed when starting a worker server.")
28+
29+
FLAGS = flags.FLAGS
30+
31+
32+
def main(unused_argv):
33+
if FLAGS.is_master:
34+
print("Starting tf.data service master")
35+
server = tf.data.experimental.service.MasterServer(
36+
port=FLAGS.port,
37+
protocol="grpc")
38+
else:
39+
print("Starting tf.data service worker")
40+
server = tf.data.experimental.service.WorkerServer(
41+
port=FLAGS.port,
42+
protocol="grpc",
43+
master_address=FLAGS.master_address,
44+
worker_address=FLAGS.worker_address)
45+
server.join()
46+
47+
48+
if __name__ == "__main__":
49+
tf.compat.v1.app.run()
