11"""
22Dataflow streaming pipeline for real-time Iris inference.
3- Reads from Pub/Sub, calls Vertex AI endpoint , writes predictions to BigQuery.
3+ Reads from Pub/Sub, calls FastAPI ML service deployed via Kubeflow, writes predictions to BigQuery.
44"""
55import json
66import logging
77import argparse
88from typing import Any , Dict , List
9+ import requests
10+ import time
911
1012import apache_beam as beam
1113from apache_beam .options .pipeline_options import PipelineOptions
1214from apache_beam .transforms import window
1315from apache_beam .io import ReadFromPubSub , WriteToBigQuery
14- from google .cloud import aiplatform
15- from google .oauth2 import service_account
1616
1717# Constants
1818PROJECT_ID = "deeplearning-sahil"
1919REGION = "us-central1"
2020MODEL_NAME = "Iris-Classifier-XGBoost"
21- ENDPOINT_NAME = "Iris-Classifier-XGBoost "
21+ FASTAPI_SERVICE_NAME = "iris-classifier-xgboost-service "
2222
2323# BigQuery schema for predictions
2424PREDICTION_SCHEMA = {
3030 {'name' : 'timestamp' , 'type' : 'TIMESTAMP' , 'mode' : 'REQUIRED' },
3131 {'name' : 'sample_id' , 'type' : 'INTEGER' , 'mode' : 'REQUIRED' },
3232 {'name' : 'prediction' , 'type' : 'STRING' , 'mode' : 'REQUIRED' },
33- {'name' : 'prediction_confidence' , 'type' : 'FLOAT' , 'mode' : 'NULLABLE' },
3433 {'name' : 'prediction_timestamp' , 'type' : 'TIMESTAMP' , 'mode' : 'REQUIRED' },
35- {'name' : 'model_endpoint' , 'type' : 'STRING' , 'mode' : 'REQUIRED' },
36- {'name' : 'processing_time' , 'type' : 'FLOAT' , 'mode' : 'NULLABLE' }
34+ {'name' : 'model_service' , 'type' : 'STRING' , 'mode' : 'REQUIRED' },
35+ {'name' : 'processing_time' , 'type' : 'FLOAT' , 'mode' : 'NULLABLE' },
36+ {'name' : 'dataflow_processing_time' , 'type' : 'TIMESTAMP' , 'mode' : 'REQUIRED' },
3737 ]
3838}
3939
@@ -57,75 +57,44 @@ def process(self, element):
5757 logging .error (f"Error parsing message: { e } , message: { element } " )
5858
5959
60- class CallVertexAIEndpoint (beam .DoFn ):
61- """Call Vertex AI model endpoint for inference."""
60+ class CallFastAPIService (beam .DoFn ):
61+ """Call FastAPI ML service for inference."""
6262
63- def __init__ (self , project : str , region : str , endpoint_name : str ):
64- self .project = project
65- self .region = region
66- self .endpoint_name = endpoint_name
67- self .client = None
68- self .endpoint = None
69-
70- def setup (self ):
71- """Initialize Vertex AI client."""
72- aiplatform .init (project = self .project , location = self .region )
73-
74- # Get the endpoint
75- endpoints = aiplatform .Endpoint .list (
76- filter = f'display_name="{ self .endpoint_name } "'
77- )
78-
79- if endpoints :
80- # If multiple endpoints exist with same name, prioritize by:
81- # 1. Most recently created (newest first)
82- # 2. Then by resource name (for consistency)
83- sorted_endpoints = sorted (
84- endpoints ,
85- key = lambda ep : (ep .create_time , ep .resource_name ),
86- reverse = True
87- )
88-
89- self .endpoint = sorted_endpoints [0 ]
90-
91- else :
92- raise RuntimeError (f"Endpoint '{ self .endpoint_name } ' not found" )
63+ def __init__ (self , service_url : str ):
64+ self .service_url = service_url
65+ self .predict_url = f"{ service_url } /predict"
9366
9467 def process (self , element ):
9568 import time
9669 from datetime import datetime
70+ import requests
9771
9872 start_time = time .time ()
9973
10074 try :
101- # Prepare features for prediction
102- features = [
103- element ['sepal_length' ],
104- element ['sepal_width' ],
105- element ['petal_length' ],
106- element ['petal_width' ]
107- ]
75+ # Prepare payload for FastAPI
76+ payload = {
77+ "instances" : [{
78+ "SepalLengthCm" : element ['sepal_length' ],
79+ "SepalWidthCm" : element ['sepal_width' ],
80+ "PetalLengthCm" : element ['petal_length' ],
81+ "PetalWidthCm" : element ['petal_width' ]
82+ }]
83+ }
10884
109- # Call the endpoint
110- predictions = self .endpoint .predict (instances = [features ])
85+ # Call FastAPI service
86+ response = requests .post (self .predict_url , json = payload , timeout = 30 )
87+ response .raise_for_status ()
11188
112- # Extract prediction result
113- prediction_result = predictions .predictions [0 ]
114-
115- logging .info (f"Prediction result: { prediction_result } " )
89+ # Parse response
90+ result_data = response .json ()
91+ predictions = result_data .get ('predictions' , [])
11692
117- # Handle different prediction formats
118- if isinstance (prediction_result , list ):
119- predicted_class = prediction_result [0 ]
120- confidence = max (prediction_result ) if len (prediction_result ) > 1 else None
93+ if predictions :
94+ prediction_result = predictions [0 ]
95+ predicted_class = str (prediction_result .get ('prediction' , 'unknown' ))
12196 else :
122- predicted_class = str (prediction_result )
123- confidence = None
124-
125- # Map numeric prediction to class name if needed
126- class_mapping = {0 : 'setosa' , 1 : 'versicolor' , 2 : 'virginica' }
127- if str (predicted_class ).isdigit ():
128- predicted_class = class_mapping .get (int (predicted_class ), str (predicted_class ))
97+ predicted_class = 'unknown'
12998
13099 processing_time = time .time () - start_time
131100
@@ -137,18 +106,17 @@ def process(self, element):
137106 'petal_width' : element ['petal_width' ],
138107 'timestamp' : element .get ('timestamp' , datetime .utcnow ().isoformat ()),
139108 'sample_id' : element .get ('sample_id' , 0 ),
140- 'prediction' : str (predicted_class ),
141- 'prediction_confidence' : confidence ,
109+ 'prediction' : predicted_class ,
142110 'prediction_timestamp' : datetime .utcnow ().isoformat (),
143- 'model_endpoint ' : f" { self .project } / { self . region } / { self . endpoint_name } " ,
111+ 'model_service ' : self .service_url ,
144112 'processing_time' : processing_time
145113 }
146114
147115 logging .info (f"Prediction for sample { element .get ('sample_id' )} : { predicted_class } " )
148116 yield result
149117
150118 except Exception as e :
151- logging .error (f"Error calling endpoint : { e } , element: { element } " )
119+ logging .error (f"Error calling FastAPI service : { e } , element: { element } " )
152120 # Yield error record for monitoring
153121 yield {
154122 'sepal_length' : element .get ('sepal_length' , 0.0 ),
@@ -158,9 +126,8 @@ def process(self, element):
158126 'timestamp' : element .get ('timestamp' , datetime .utcnow ().isoformat ()),
159127 'sample_id' : element .get ('sample_id' , 0 ),
160128 'prediction' : 'ERROR' ,
161- 'prediction_confidence' : None ,
162129 'prediction_timestamp' : datetime .utcnow ().isoformat (),
163- 'model_endpoint ' : f"ERROR: { str (e )} " ,
130+ 'model_service ' : f"ERROR: { str (e )} " ,
164131 'processing_time' : time .time () - start_time
165132 }
166133
@@ -173,7 +140,6 @@ def process(self, element):
173140
174141 # Add additional metadata
175142 element ['dataflow_processing_time' ] = datetime .utcnow ().isoformat ()
176- element ['pipeline_version' ] = '1.0.0'
177143
178144 yield element
179145
@@ -203,9 +169,9 @@ def run_pipeline(argv=None):
203169 help = 'GCP Region'
204170 )
205171 parser .add_argument (
206- '--endpoint_name ' ,
172+ '--service_url ' ,
207173 required = True ,
208- help = 'Vertex AI endpoint name '
174+ help = 'FastAPI service URL '
209175 )
210176
211177 known_args , pipeline_args = parser .parse_known_args (argv )
@@ -227,11 +193,8 @@ def run_pipeline(argv=None):
227193 pipeline
228194 | 'Read from Pub/Sub' >> ReadFromPubSub (topic = known_args .input_topic )
229195 | 'Parse JSON' >> beam .ParDo (ParsePubSubMessage ())
230- | 'Add Window' >> beam .WindowInto (window .FixedWindows (60 )) # 1-minute windows
231- | 'Call Vertex AI' >> beam .ParDo (CallVertexAIEndpoint (
232- known_args .project_id ,
233- known_args .region ,
234- known_args .endpoint_name ))
196+ | 'Call FastAPI Service' >> beam .ParDo (CallFastAPIService (
197+ known_args .service_url ))
235198 | 'Add Metadata' >> beam .ParDo (AddProcessingMetadata ())
236199 | 'Write to BigQuery' >> WriteToBigQuery (
237200 table = known_args .output_table ,
0 commit comments