Skip to content

Commit 6262fee

Browse files
committed
inference with fastapi
1 parent 92dca5c commit 6262fee

5 files changed

Lines changed: 80 additions & 124 deletions

File tree

deploy_dataflow_streaming.sh

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/bin/bash
22

3-
# Deploy Dataflow streaming job for real-time Iris inference
3+
# Deploy Dataflow streaming job for real-time Iris inference using FastAPI service
44
set -e
55

66
# Configuration
@@ -12,9 +12,10 @@ OUTPUT_TABLE="$PROJECT_ID:ml_dataset.iris_predictions_streaming"
1212
TEMP_LOCATION="gs://sb-vertex/temp"
1313
STAGING_LOCATION="gs://sb-vertex/staging"
1414
SERVICE_ACCOUNT="kfp-mlops@deeplearning-sahil.iam.gserviceaccount.com"
15-
ENDPOINT_NAME="Iris-Classifier-XGBoost"
15+
SERVICE_URL="https://iris-classifier-xgboost-service-zoxyfmo73q-uc.a.run.app"
1616

17-
echo "Deploying Dataflow streaming job for real-time inference..."
17+
echo "Deploying Dataflow streaming job for real-time inference using FastAPI service..."
18+
echo "Note: Update SERVICE_URL with the actual Cloud Run service URL after deployment"
1819

1920
# Run the Dataflow job
2021
echo "Starting Dataflow streaming job: $JOB_NAME"
@@ -23,7 +24,7 @@ python src/ml_pipelines_kfp/dataflow/iris_streaming_pipeline.py \
2324
--output_table $OUTPUT_TABLE \
2425
--project_id $PROJECT_ID \
2526
--region $REGION \
26-
--endpoint_name $ENDPOINT_NAME \
27+
--service_url $SERVICE_URL \
2728
--runner DataflowRunner \
2829
--job_name $JOB_NAME \
2930
--temp_location $TEMP_LOCATION \

pipeline.yaml

Lines changed: 34 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,7 @@ deploymentSpec:
257257
\ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' &&\
258258
\ python3 -m pip install --quiet --no-warn-script-location 'google-cloud-aiplatform>=1.59.0'\
259259
\ 'google-cloud-run>=0.10.0' 'google-cloud-storage>=2.10.0' 'requests>=2.31.0'\
260-
\ 'joblib>=1.4.2' 'scikit-learn>=1.3.0' 'pandas>=2.0.0' 'numpy>=1.24.0'\
261-
\ 'grpcio-status>=1.62.3' && \"$0\" \"$@\"\n"
260+
\ 'joblib>=1.4.2' 'grpcio-status>=1.62.3' && \"$0\" \"$@\"\n"
262261
- sh
263262
- -ec
264263
- 'program_path=$(mktemp -d)
@@ -286,17 +285,15 @@ deploymentSpec:
286285
\ credentials\n client = aiplatform_v1.ModelServiceClient(\n credentials=credentials,\n\
287286
\ client_options={\"api_endpoint\": f\"{location}-aiplatform.googleapis.com\"\
288287
}\n )\n\n print(f\"Searching for blessed model with name: {model_name}\"\
289-
)\n\n # Use the high-level aiplatform library to list all model versions\n\
290-
\ # models = aiplatform.Model.list(filter=f\"display_name={model_name}\"\
291-
)\n # blessed_model = None\n\n request = {\n \"parent\"\
292-
: f\"projects/{project_id}/locations/{location}\",\n \"filter\"\
293-
: f\"display_name={model_name}\"\n }\n\n models = list(client.list_models(request=request))\n\
294-
\ blessed_model = None\n\n print(f\"Found {len(models)} model versions\
295-
\ with name {model_name}\")\n\n # Search through all model versions (each\
296-
\ item in models is already a version)\n for parent_model in models:\n\
297-
\ print(f\"Checking parent model: {parent_model.name}\")\n\n \
298-
\ # List all versions of this model\n versions_request = {\"name\"\
299-
: parent_model.name}\n versions = list(client.list_model_versions(request=versions_request))\n\
288+
)\n\n request = {\n \"parent\": f\"projects/{project_id}/locations/{location}\"\
289+
,\n \"filter\": f\"display_name={model_name}\"\n }\n\n\
290+
\ models = list(client.list_models(request=request))\n blessed_model\
291+
\ = None\n\n print(f\"Found {len(models)} model versions with name {model_name}\"\
292+
)\n\n # Search through all model versions (each item in models is already\
293+
\ a version)\n for parent_model in models:\n print(f\"Checking\
294+
\ parent model: {parent_model.name}\")\n\n # List all versions of\
295+
\ this model\n versions_request = {\"name\": parent_model.name}\n\
296+
\ versions = list(client.list_model_versions(request=versions_request))\n\
300297
\n print(f\"Found {len(versions)} versions for this model\")\n\n\
301298
\ for version in versions:\n print(f\"Version {version.version_id}:\
302299
\ Aliases = {list(version.version_aliases)}\")\n if \"blessed\"\
@@ -308,34 +305,29 @@ deploymentSpec:
308305
\ ValueError(f\"No blessed version found for model {model_name}. Available\
309306
\ versions: {available_versions}\")\n\n print(f\"Found blessed model:\
310307
\ {blessed_model.name}\")\n print(f\"Model URI: {blessed_model.artifact_uri}\"\
311-
)\n\n # 2. Download joblib model from blessed version\n gcs_uri =\
312-
\ blessed_model.artifact_uri\n if not gcs_uri.startswith('gs://'):\n\
313-
\ raise ValueError(f\"Expected GCS URI, got: {gcs_uri}\")\n\n \
314-
\ bucket_name = gcs_uri.replace('gs://', '').split('/')[0]\n model_path\
315-
\ = '/'.join(gcs_uri.replace('gs://', '').split('/')[1:])\n\n print(f\"\
316-
Downloading model from gs://{bucket_name}/{model_path}\")\n\n storage_client\
317-
\ = storage.Client()\n bucket = storage_client.bucket(bucket_name)\n\n\
318-
\ # Download and validate the model\n model_blob_path = f\"{model_path}/model.joblib\"\
308+
)\n\n # Download joblib model from blessed version\n gcs_uri = blessed_model.artifact_uri\n\
309+
\ if not gcs_uri.startswith('gs://'):\n raise ValueError(f\"Expected\
310+
\ GCS URI, got: {gcs_uri}\")\n\n bucket_name = gcs_uri.replace('gs://',\
311+
\ '').split('/')[0]\n model_path = '/'.join(gcs_uri.replace('gs://',\
312+
\ '').split('/')[1:])\n\n print(f\"Downloading model from gs://{bucket_name}/{model_path}\"\
313+
)\n\n storage_client = storage.Client()\n bucket = storage_client.bucket(bucket_name)\n\
314+
\n # Download and validate the model\n model_blob_path = f\"{model_path}/model.joblib\"\
319315
\n blob = bucket.blob(model_blob_path)\n\n if not blob.exists():\n\
320316
\ raise ValueError(f\"Model file not found at gs://{bucket_name}/{model_blob_path}\"\
321317
)\n\n with tempfile.NamedTemporaryFile(suffix='.joblib', delete=False)\
322318
\ as temp_file:\n blob.download_to_filename(temp_file.name)\n \
323319
\ local_model_path = temp_file.name\n\n print(f\"Downloaded model\
324-
\ to: {local_model_path}\")\n\n # 3. Validate model can be loaded\n \
325-
\ try:\n model_obj = joblib.load(local_model_path)\n print(f\"\
326-
Model type: {type(model_obj)}\")\n print(f\"Model validation successful\"\
327-
)\n except Exception as e:\n os.unlink(local_model_path)\n \
328-
\ raise ValueError(f\"Model validation failed: {e}\")\n\n # 4. Copy\
329-
\ model to standard deployment location\n deployment_model_path = f\"\
330-
deployed-models/{service_name}/model.joblib\"\n deployment_blob = bucket.blob(deployment_model_path)\n\
331-
\n print(f\"Copying model to deployment location: gs://{bucket_name}/{deployment_model_path}\"\
320+
\ to: {local_model_path}\")\n\n # Copy model to standard deployment location\n\
321+
\ deployment_model_path = f\"deployed-models/{service_name}/model.joblib\"\
322+
\n deployment_blob = bucket.blob(deployment_model_path)\n\n print(f\"\
323+
Copying model to deployment location: gs://{bucket_name}/{deployment_model_path}\"\
332324
)\n deployment_blob.upload_from_filename(local_model_path)\n\n model_gcs_path\
333325
\ = f\"gs://{bucket_name}/{deployment_model_path}\"\n print(f\"Model\
334-
\ available at: {model_gcs_path}\")\n\n # 5. Deploy to Cloud Run using\
335-
\ pre-built generic image\n print(f\"Deploying to Cloud Run service:\
336-
\ {service_name}\")\n\n run_client = run_v2.ServicesClient()\n\n #\
337-
\ Use pre-built generic FastAPI image from CI/CD\n generic_image = fastapi_image_name\n\
338-
\n service_config = {\n \"parent\": f\"projects/{project_id}/locations/{location}\"\
326+
\ available at: {model_gcs_path}\")\n\n # Deploy to Cloud Run using pre-built\
327+
\ generic image\n print(f\"Deploying to Cloud Run service: {service_name}\"\
328+
)\n\n run_client = run_v2.ServicesClient()\n\n # Use pre-built generic\
329+
\ FastAPI image from CI/CD\n generic_image = fastapi_image_name\n\n \
330+
\ service_config = {\n \"parent\": f\"projects/{project_id}/locations/{location}\"\
339331
,\n \"service_id\": service_name,\n \"service\": {\n \
340332
\ \"template\": {\n \"containers\": [{\n \
341333
\ \"image\": generic_image,\n \"ports\": [{\"\
@@ -370,14 +362,14 @@ deploymentSpec:
370362
\ resource=result.name, # This should be the full resource name\n\
371363
\ policy=policy\n )\n run_client.set_iam_policy(request=iam_request)\n\
372364
\n service_url = result.uri\n print(f\"Service deployed successfully\
373-
\ to: {service_url}\")\n\n # 6. Test deployment\n print(\"\
374-
Testing deployment...\")\n time.sleep(30) # Wait for service to\
375-
\ be ready\n\n test_payload = {\n \"instances\": [\n \
376-
\ {\"SepalLengthCm\": 5.1, \"SepalWidthCm\": 3.5, \"PetalLengthCm\"\
377-
: 1.4, \"PetalWidthCm\": 0.2}\n ]\n }\n\n try:\n\
378-
\ # Test health endpoint first\n health_response =\
379-
\ requests.get(f\"{service_url}/health\", timeout=30)\n print(f\"\
380-
Health check status: {health_response.status_code}\")\n if health_response.status_code\
365+
\ to: {service_url}\")\n\n # Test deployment\n print(\"Testing\
366+
\ deployment...\")\n time.sleep(30) # Wait for service to be ready\n\
367+
\n test_payload = {\n \"instances\": [\n \
368+
\ {\"SepalLengthCm\": 5.1, \"SepalWidthCm\": 3.5, \"PetalLengthCm\": 1.4,\
369+
\ \"PetalWidthCm\": 0.2}\n ]\n }\n\n try:\n \
370+
\ # Test health endpoint first\n health_response = requests.get(f\"\
371+
{service_url}/health\", timeout=30)\n print(f\"Health check status:\
372+
\ {health_response.status_code}\")\n if health_response.status_code\
381373
\ == 200:\n print(f\"Health check response: {health_response.json()}\"\
382374
)\n\n # Test prediction endpoint\n response = requests.post(\n\
383375
\ f\"{service_url}/predict\", \n json=test_payload,\n\

src/ml_pipelines_kfp/dataflow/iris_streaming_pipeline.py

Lines changed: 40 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,24 @@
11
"""
22
Dataflow streaming pipeline for real-time Iris inference.
3-
Reads from Pub/Sub, calls Vertex AI endpoint, writes predictions to BigQuery.
3+
Reads from Pub/Sub, calls FastAPI ML service deployed via Kubeflow, writes predictions to BigQuery.
44
"""
55
import json
66
import logging
77
import argparse
88
from typing import Any, Dict, List
9+
import requests
10+
import time
911

1012
import apache_beam as beam
1113
from apache_beam.options.pipeline_options import PipelineOptions
1214
from apache_beam.transforms import window
1315
from apache_beam.io import ReadFromPubSub, WriteToBigQuery
14-
from google.cloud import aiplatform
15-
from google.oauth2 import service_account
1616

1717
# Constants
1818
PROJECT_ID = "deeplearning-sahil"
1919
REGION = "us-central1"
2020
MODEL_NAME = "Iris-Classifier-XGBoost"
21-
ENDPOINT_NAME = "Iris-Classifier-XGBoost"
21+
FASTAPI_SERVICE_NAME = "iris-classifier-xgboost-service"
2222

2323
# BigQuery schema for predictions
2424
PREDICTION_SCHEMA = {
@@ -30,10 +30,10 @@
3030
{'name': 'timestamp', 'type': 'TIMESTAMP', 'mode': 'REQUIRED'},
3131
{'name': 'sample_id', 'type': 'INTEGER', 'mode': 'REQUIRED'},
3232
{'name': 'prediction', 'type': 'STRING', 'mode': 'REQUIRED'},
33-
{'name': 'prediction_confidence', 'type': 'FLOAT', 'mode': 'NULLABLE'},
3433
{'name': 'prediction_timestamp', 'type': 'TIMESTAMP', 'mode': 'REQUIRED'},
35-
{'name': 'model_endpoint', 'type': 'STRING', 'mode': 'REQUIRED'},
36-
{'name': 'processing_time', 'type': 'FLOAT', 'mode': 'NULLABLE'}
34+
{'name': 'model_service', 'type': 'STRING', 'mode': 'REQUIRED'},
35+
{'name': 'processing_time', 'type': 'FLOAT', 'mode': 'NULLABLE'},
36+
{'name': 'dataflow_processing_time', 'type': 'TIMESTAMP', 'mode': 'REQUIRED'},
3737
]
3838
}
3939

@@ -57,75 +57,44 @@ def process(self, element):
5757
logging.error(f"Error parsing message: {e}, message: {element}")
5858

5959

60-
class CallVertexAIEndpoint(beam.DoFn):
61-
"""Call Vertex AI model endpoint for inference."""
60+
class CallFastAPIService(beam.DoFn):
61+
"""Call FastAPI ML service for inference."""
6262

63-
def __init__(self, project: str, region: str, endpoint_name: str):
64-
self.project = project
65-
self.region = region
66-
self.endpoint_name = endpoint_name
67-
self.client = None
68-
self.endpoint = None
69-
70-
def setup(self):
71-
"""Initialize Vertex AI client."""
72-
aiplatform.init(project=self.project, location=self.region)
73-
74-
# Get the endpoint
75-
endpoints = aiplatform.Endpoint.list(
76-
filter=f'display_name="{self.endpoint_name}"'
77-
)
78-
79-
if endpoints:
80-
# If multiple endpoints exist with same name, prioritize by:
81-
# 1. Most recently created (newest first)
82-
# 2. Then by resource name (for consistency)
83-
sorted_endpoints = sorted(
84-
endpoints,
85-
key=lambda ep: (ep.create_time, ep.resource_name),
86-
reverse=True
87-
)
88-
89-
self.endpoint = sorted_endpoints[0]
90-
91-
else:
92-
raise RuntimeError(f"Endpoint '{self.endpoint_name}' not found")
63+
def __init__(self, service_url: str):
64+
self.service_url = service_url
65+
self.predict_url = f"{service_url}/predict"
9366

9467
def process(self, element):
9568
import time
9669
from datetime import datetime
70+
import requests
9771

9872
start_time = time.time()
9973

10074
try:
101-
# Prepare features for prediction
102-
features = [
103-
element['sepal_length'],
104-
element['sepal_width'],
105-
element['petal_length'],
106-
element['petal_width']
107-
]
75+
# Prepare payload for FastAPI
76+
payload = {
77+
"instances": [{
78+
"SepalLengthCm": element['sepal_length'],
79+
"SepalWidthCm": element['sepal_width'],
80+
"PetalLengthCm": element['petal_length'],
81+
"PetalWidthCm": element['petal_width']
82+
}]
83+
}
10884

109-
# Call the endpoint
110-
predictions = self.endpoint.predict(instances=[features])
85+
# Call FastAPI service
86+
response = requests.post(self.predict_url, json=payload, timeout=30)
87+
response.raise_for_status()
11188

112-
# Extract prediction result
113-
prediction_result = predictions.predictions[0]
114-
115-
logging.info(f"Prediction result: {prediction_result}")
89+
# Parse response
90+
result_data = response.json()
91+
predictions = result_data.get('predictions', [])
11692

117-
# Handle different prediction formats
118-
if isinstance(prediction_result, list):
119-
predicted_class = prediction_result[0]
120-
confidence = max(prediction_result) if len(prediction_result) > 1 else None
93+
if predictions:
94+
prediction_result = predictions[0]
95+
predicted_class = str(prediction_result.get('prediction', 'unknown'))
12196
else:
122-
predicted_class = str(prediction_result)
123-
confidence = None
124-
125-
# Map numeric prediction to class name if needed
126-
class_mapping = {0: 'setosa', 1: 'versicolor', 2: 'virginica'}
127-
if str(predicted_class).isdigit():
128-
predicted_class = class_mapping.get(int(predicted_class), str(predicted_class))
97+
predicted_class = 'unknown'
12998

13099
processing_time = time.time() - start_time
131100

@@ -137,18 +106,17 @@ def process(self, element):
137106
'petal_width': element['petal_width'],
138107
'timestamp': element.get('timestamp', datetime.utcnow().isoformat()),
139108
'sample_id': element.get('sample_id', 0),
140-
'prediction': str(predicted_class),
141-
'prediction_confidence': confidence,
109+
'prediction': predicted_class,
142110
'prediction_timestamp': datetime.utcnow().isoformat(),
143-
'model_endpoint': f"{self.project}/{self.region}/{self.endpoint_name}",
111+
'model_service': self.service_url,
144112
'processing_time': processing_time
145113
}
146114

147115
logging.info(f"Prediction for sample {element.get('sample_id')}: {predicted_class}")
148116
yield result
149117

150118
except Exception as e:
151-
logging.error(f"Error calling endpoint: {e}, element: {element}")
119+
logging.error(f"Error calling FastAPI service: {e}, element: {element}")
152120
# Yield error record for monitoring
153121
yield {
154122
'sepal_length': element.get('sepal_length', 0.0),
@@ -158,9 +126,8 @@ def process(self, element):
158126
'timestamp': element.get('timestamp', datetime.utcnow().isoformat()),
159127
'sample_id': element.get('sample_id', 0),
160128
'prediction': 'ERROR',
161-
'prediction_confidence': None,
162129
'prediction_timestamp': datetime.utcnow().isoformat(),
163-
'model_endpoint': f"ERROR: {str(e)}",
130+
'model_service': f"ERROR: {str(e)}",
164131
'processing_time': time.time() - start_time
165132
}
166133

@@ -173,7 +140,6 @@ def process(self, element):
173140

174141
# Add additional metadata
175142
element['dataflow_processing_time'] = datetime.utcnow().isoformat()
176-
element['pipeline_version'] = '1.0.0'
177143

178144
yield element
179145

@@ -203,9 +169,9 @@ def run_pipeline(argv=None):
203169
help='GCP Region'
204170
)
205171
parser.add_argument(
206-
'--endpoint_name',
172+
'--service_url',
207173
required=True,
208-
help='Vertex AI endpoint name'
174+
help='FastAPI service URL'
209175
)
210176

211177
known_args, pipeline_args = parser.parse_known_args(argv)
@@ -227,11 +193,8 @@ def run_pipeline(argv=None):
227193
pipeline
228194
| 'Read from Pub/Sub' >> ReadFromPubSub(topic=known_args.input_topic)
229195
| 'Parse JSON' >> beam.ParDo(ParsePubSubMessage())
230-
| 'Add Window' >> beam.WindowInto(window.FixedWindows(60)) # 1-minute windows
231-
| 'Call Vertex AI' >> beam.ParDo(CallVertexAIEndpoint(
232-
known_args.project_id,
233-
known_args.region,
234-
known_args.endpoint_name))
196+
| 'Call FastAPI Service' >> beam.ParDo(CallFastAPIService(
197+
known_args.service_url))
235198
| 'Add Metadata' >> beam.ParDo(AddProcessingMetadata())
236199
| 'Write to BigQuery' >> WriteToBigQuery(
237200
table=known_args.output_table,

src/ml_pipelines_kfp/iris_xgboost/pipelines/components/fastapi/Dockerfile.fastapi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ COPY requirements.fastapi.txt requirements.txt
1515
RUN pip install --no-cache-dir -r requirements.txt
1616

1717
# Copy FastAPI application
18-
COPY fastapi_server_template.py main.py
18+
COPY fastapi_server.py main.py
1919

2020
# Create directory for models
2121
RUN mkdir -p /app/models

src/ml_pipelines_kfp/iris_xgboost/pipelines/components/fastapi/fastapi_server_template.py renamed to src/ml_pipelines_kfp/iris_xgboost/pipelines/components/fastapi/fastapi_server.py

File renamed without changes.

0 commit comments

Comments (0)