Skip to content

Commit ba30940

Browse files
committed
chore(dataflow/gemma): update dependencies and format code
- Update tensorflow base image to 2.20.0-gpu and beam sdk to 3.11/2.74.0
- Update apache_beam, keras, keras_nlp, and protobuf dependencies
- Update test dependencies including google-cloud-aiplatform, storage, and pytest
- Format custom_model_gemma.py and e2e_test.py
- Update ignored python versions in noxfile_config.py
1 parent e95b223 commit ba30940

6 files changed

Lines changed: 49 additions & 45 deletions

File tree

dataflow/gemma/Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# This uses Ubuntu with Python 3.11
1616
# You can check the Python version for a given tensorflow
1717
# container at https://hub.docker.com/r/tensorflow/tensorflow/tags
18-
ARG SERVING_BUILD_IMAGE=tensorflow/tensorflow:2.16.1-gpu
18+
ARG SERVING_BUILD_IMAGE=tensorflow/tensorflow:2.20.0-gpu
1919

2020
FROM ${SERVING_BUILD_IMAGE}
2121

@@ -29,7 +29,7 @@ RUN pip install --upgrade --no-cache-dir pip \
2929
&& pip install --no-cache-dir -r requirements.txt
3030

3131
# Copy files from official SDK image, including script/dependencies.
32-
COPY --from=apache/beam_python3.14_sdk:2.73.0 /opt/apache/beam /opt/apache/beam
32+
COPY --from=apache/beam_python3.11_sdk:2.74.0 /opt/apache/beam /opt/apache/beam
3333

3434
# Copy the model directory downloaded from Kaggle and the pipeline code.
3535
COPY gemma_2b gemma_2B

dataflow/gemma/custom_model_gemma.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(
3535
self,
3636
model_name: str = "gemma_2B",
3737
):
38-
""" Implementation of the ModelHandler interface for Gemma using text as input.
38+
"""Implementation of the ModelHandler interface for Gemma using text as input.
3939
4040
Example Usage::
4141
@@ -48,7 +48,7 @@ def __init__(
4848
self._env_vars = {}
4949

5050
def share_model_across_processes(self) -> bool:
51-
""" Indicates if the model should be loaded once-per-VM rather than
51+
"""Indicates if the model should be loaded once-per-VM rather than
5252
once-per-worker-process on a VM. Because Gemma is a large language model,
5353
this will always return True to avoid OOM errors.
5454
"""
@@ -62,7 +62,7 @@ def run_inference(
6262
self,
6363
batch: Sequence[str],
6464
model: GemmaCausalLM,
65-
inference_args: Optional[dict[str, Any]] = None
65+
inference_args: Optional[dict[str, Any]] = None,
6666
) -> Iterable[PredictionResult]:
6767
"""Runs inferences on a batch of text strings.
6868
@@ -85,7 +85,8 @@ def run_inference(
8585
class FormatOutput(beam.DoFn):
8686
def process(self, element, *args, **kwargs):
8787
yield "Input: {input}, Output: {output}".format(
88-
input=element.example, output=element.inference)
88+
input=element.example, output=element.inference
89+
)
8990

9091

9192
if __name__ == "__main__":
@@ -119,13 +120,16 @@ def process(self, element, *args, **kwargs):
119120

120121
pipeline = beam.Pipeline(options=beam_options)
121122
_ = (
122-
pipeline | "Read Topic" >>
123-
beam.io.ReadFromPubSub(subscription=args.messages_subscription)
123+
pipeline
124+
| "Read Topic"
125+
>> beam.io.ReadFromPubSub(subscription=args.messages_subscription)
124126
| "Parse" >> beam.Map(lambda x: x.decode("utf-8"))
125-
| "RunInference-Gemma" >> RunInference(
127+
| "RunInference-Gemma"
128+
>> RunInference(
126129
GemmaModelHandler(args.model_path)
127130
) # Send the prompts to the model and get responses.
128131
| "Format Output" >> beam.ParDo(FormatOutput()) # Format the output.
129-
| "Publish Result" >>
130-
beam.io.gcp.pubsub.WriteStringsToPubSub(topic=args.responses_topic))
132+
| "Publish Result"
133+
>> beam.io.gcp.pubsub.WriteStringsToPubSub(topic=args.responses_topic)
134+
)
131135
pipeline.run()

dataflow/gemma/e2e_test.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
NOTE: For the tests to find the conftest in the testing infrastructure,
4040
add the PYTHONPATH to the "env" in your noxfile_config.py file.
4141
"""
42+
4243
from collections.abc import Callable, Iterator
4344

4445
import conftest # python-docs-samples/dataflow/conftest.py
@@ -70,8 +71,9 @@ def messages_topic(pubsub_topic: Callable[[str], str]) -> str:
7071

7172

7273
@pytest.fixture(scope="session")
73-
def messages_subscription(pubsub_subscription: Callable[[str, str], str],
74-
messages_topic: str) -> str:
74+
def messages_subscription(
75+
pubsub_subscription: Callable[[str, str], str], messages_topic: str
76+
) -> str:
7577
return pubsub_subscription("messages", messages_topic)
7678

7779

@@ -81,20 +83,21 @@ def responses_topic(pubsub_topic: Callable[[str], str]) -> str:
8183

8284

8385
@pytest.fixture(scope="session")
84-
def responses_subscription(pubsub_subscription: Callable[[str, str], str],
85-
responses_topic: str) -> str:
86+
def responses_subscription(
87+
pubsub_subscription: Callable[[str, str], str], responses_topic: str
88+
) -> str:
8689
return pubsub_subscription("responses", responses_topic)
8790

8891

8992
@pytest.fixture(scope="session")
9093
def dataflow_job(
91-
project: str,
92-
bucket_name: str,
93-
location: str,
94-
unique_name: str,
95-
container_image: str,
96-
messages_subscription: str,
97-
responses_topic: str,
94+
project: str,
95+
bucket_name: str,
96+
location: str,
97+
unique_name: str,
98+
container_image: str,
99+
messages_subscription: str,
100+
responses_topic: str,
98101
) -> Iterator[str]:
99102
# Launch the streaming Dataflow pipeline.
100103
conftest.run_cmd(
@@ -127,20 +130,18 @@ def dataflow_job(
127130

128131
@pytest.mark.timeout(3600)
129132
def test_pipeline_dataflow(
130-
project: str,
131-
location: str,
132-
dataflow_job: str,
133-
messages_topic: str,
134-
responses_subscription: str,
133+
project: str,
134+
location: str,
135+
dataflow_job: str,
136+
messages_topic: str,
137+
responses_subscription: str,
135138
) -> None:
136139
print(f"Waiting for the Dataflow workers to start: {dataflow_job}")
137140
conftest.wait_until(
138-
lambda: conftest.dataflow_num_workers(project, location, dataflow_job)
139-
> 0,
141+
lambda: conftest.dataflow_num_workers(project, location, dataflow_job) > 0,
140142
"workers are running",
141143
)
142-
num_workers = conftest.dataflow_num_workers(project, location,
143-
dataflow_job)
144+
num_workers = conftest.dataflow_num_workers(project, location, dataflow_job)
144145
print(f"Dataflow job num_workers: {num_workers}")
145146

146147
messages = ["This is a test for a Python sample."]

dataflow/gemma/noxfile_config.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@
1818
# You can opt out from the test for specific Python versions.
1919
# The Python version used is defined by the Dockerfile and the job
2020
# submission environment must match.
21-
# Note: Docker-based sample, testing only against version specified in Dockerfile (3.14)
22-
"ignored_versions": ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"],
23-
"envs": {
24-
"PYTHONPATH": ".."
25-
},
21+
# Note: Docker-based sample, testing only against version specified in Dockerfile (3.11)
22+
"ignored_versions": ["3.8", "3.9", "3.10"],
23+
"envs": {"PYTHONPATH": ".."},
2624
}
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
google-cloud-aiplatform==1.49.0
2-
google-cloud-dataflow-client==0.8.10
3-
google-cloud-storage==2.16.0
4-
pytest==9.0.3; python_version >= "3.10"
5-
pytest-timeout==2.3.1
1+
google-cloud-aiplatform==1.157.0
2+
google-cloud-dataflow-client==0.14.0
3+
google-cloud-storage==3.12.0
4+
pytest==9.0.3
5+
pytest-timeout==2.4.0

dataflow/gemma/requirements.txt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
apache_beam[gcp]==2.54.0
2-
protobuf==4.25.0
3-
keras_nlp==0.8.2
4-
keras==3.0.5
1+
protobuf==6.33.6
2+
apache_beam[gcp]==2.74.0
3+
keras==3.14.1
4+
keras_nlp==0.29.1
5+
pyOpenSSL==25.3.0

0 commit comments

Comments
 (0)