1- #BASED ON EXAMPLE FROM MLFLOW DOCS
1+ # BASED ON EXAMPLE FROM MLFLOW DOCS
22# https://github.com/mlflow/mlflow/blob/master/examples/prophet/train.py
33import pandas as pd
4- from fbprophet import Prophet
4+ from prophet import Prophet
55
6- from fbprophet .diagnostics import cross_validation
7- from fbprophet .diagnostics import performance_metrics
6+ from prophet .diagnostics import cross_validation
7+ from prophet .diagnostics import performance_metrics
88
99import mlflow
1010import mlflow .pyfunc
11+ from typing import Any
12+ from mlflow .models .signature import infer_signature
1113
1214import logging
15+
1316logging .basicConfig (level = logging .WARN )
1417logger = logging .getLogger (__name__ )
1518
19+
1620class FbProphetWrapper (mlflow .pyfunc .PythonModel ):
1721 def __init__ (self , model ):
22+ """
23+ Initialize a FbProphetWrapper instance.
24+
25+ Parameters
26+ ----------
27+ model : Prophet
28+ The Prophet model to be wrapped.
29+
30+ Returns
31+ -------
32+ None
33+ """
1834 self .model = model
1935 super ().__init__ ()
2036
2137 def load_context (self , context ):
22- from fbprophet import Prophet
23-
24- return
25-
26- def predict (self , context , model_input ):
27- future = self .model .make_future_dataframe (periods = model_input ["periods" ][0 ])
28- return self .model .predict (future )
38+ """
39+ Load the model context from the MLflow server.
40+
41+ Parameters
42+ ----------
43+ context : mlflow.pyfunc.ModelContext
44+ The model context to be loaded.
45+
46+ Returns
47+ -------
48+ None
49+
50+ Notes
51+ -----
52+ This function is called by the MLflow server to load the model context
53+ from the server. The context is loaded into the model object, and
54+ can be accessed by the model object's methods.
55+ """
56+ # context is provided by MLflow at load time. If you need to
57+ # restore external artifacts, do it here. We keep this simple
58+ # because the Prophet model object is already serialized by MLflow
59+ # when logging the pyfunc wrapper.
60+ return None
61+
62+ def predict (self , context : Any , model_input : pd .DataFrame ) -> pd .DataFrame :
63+ """Make predictions using the wrapped Prophet model.
64+
65+ Parameters
66+ ----------
67+ context : mlflow.pyfunc.ModelContext
68+ MLflow model context (provided at prediction time).
69+ model_input : pd.DataFrame
70+ Input dataframe. Two supported modes:
71+ - If it contains a column named "periods", the first value
72+ is taken as an integer number of periods to forecast and
73+ a future dataframe is created via Prophet's
74+ `make_future_dataframe`.
75+ - If it contains a column named "ds", it is treated as a
76+ dataframe of datestamps to run predict() on directly.
77+
78+ Returns
79+ -------
80+ pd.DataFrame
81+ The DataFrame returned by Prophet's `predict` method.
82+ """
83+ # If user passed periods (e.g., during local scoring tests), create future df
84+ try :
85+ if isinstance (model_input , dict ):
86+ # backward compatible: accept dict with 'periods'
87+ periods = int (model_input .get ("periods" , [0 ])[0 ])
88+ future = self .model .make_future_dataframe (periods = periods )
89+ return self .model .predict (future )
90+
91+ if "periods" in model_input .columns :
92+ periods = (
93+ int (model_input .loc [0 , "periods" ]) if not model_input .empty else 0
94+ )
95+ future = self .model .make_future_dataframe (periods = periods )
96+ return self .model .predict (future )
97+
98+ # If input contains datestamps (ds), predict for those rows
99+ if "ds" in model_input .columns :
100+ # Ensure ds is datetime
101+ mi = model_input .copy ()
102+ mi ["ds" ] = pd .to_datetime (mi ["ds" ], errors = "coerce" )
103+ mi = mi .dropna (subset = ["ds" ]) # drop invalid dates
104+ return self .model .predict (mi )
105+
106+ # Fallback: try to coerce the whole input to numeric periods
107+ # and create a small future frame
108+ # (this is defensive; prefer explicit 'periods' or 'ds')
109+ periods = 0
110+ if not model_input .empty :
111+ # try to use the first column as integer periods
112+ try :
113+ periods = int (
114+ pd .to_numeric (model_input .iloc [0 , 0 ], errors = "coerce" )
115+ )
116+ except Exception :
117+ periods = 0
118+ future = self .model .make_future_dataframe (periods = periods )
119+ return self .model .predict (future )
120+ except Exception as e :
121+ # Raise the exception so MLflow can capture the failure; include
122+ # a helpful message for debugging.
123+ raise RuntimeError (f"Error in FbProphetWrapper.predict: { e } " ) from e
124+
125+
126+ seasonality = {"yearly" : True , "weekly" : True , "daily" : True }
29127
30128
31- seasonality = {
32- 'yearly' : True ,
33- 'weekly' : True ,
34- 'daily' : True
35- }
36-
37129def train_predict (df_all_data , df_all_train_index , seasonality_params = seasonality ):
130+ """
131+ Train a Prophet model on the given data and log the model and its metrics to MLflow.
132+
133+ Parameters
134+ ----------
135+ df_all_data : pandas.DataFrame
136+ The dataframe containing all the data to be split into train and test sets.
137+ df_all_train_index : int
138+ The index to split the dataframe into train and test sets.
139+ seasonality_params : dict, optional
140+ A dictionary containing the yearly, weekly, and daily seasonality parameters for the Prophet model.
141+
142+ Returns
143+ -------
144+ tuple[pandas.DataFrame, pandas.DataFrame, pandas.DataFrame]
145+ A tuple containing the predicted values, the train dataframe, and the test dataframe.
146+
147+ Notes
148+ -----
149+ This function will log the model and its metrics to MLflow. The model will be logged with the name "model" and the metrics will be logged with the name "rmse".
150+ """
38151 # grab split data
39152 df_train = df_all_data .copy ().iloc [0 :df_all_train_index ]
40153 df_test = df_all_data .copy ().iloc [df_all_train_index :]
@@ -43,15 +156,17 @@ def train_predict(df_all_data, df_all_train_index, seasonality_params=seasonalit
43156 with mlflow .start_run ():
44157 # create Prophet model
45158 model = Prophet (
46- yearly_seasonality = seasonality_params [' yearly' ],
47- weekly_seasonality = seasonality_params [' weekly' ],
48- daily_seasonality = seasonality_params [' daily' ]
159+ yearly_seasonality = seasonality_params [" yearly" ],
160+ weekly_seasonality = seasonality_params [" weekly" ],
161+ daily_seasonality = seasonality_params [" daily" ],
49162 )
50163 # train and predict
51164 model .fit (df_train )
52165
53166 # Evaluate Metrics
54- df_cv = cross_validation (model , initial = "730 days" , period = "180 days" , horizon = "365 days" )
167+ df_cv = cross_validation (
168+ model , initial = "540 days" , period = "180 days" , horizon = "90 days"
169+ )
55170 df_p = performance_metrics (df_cv )
56171
57172 # Print out metrics
@@ -61,7 +176,30 @@ def train_predict(df_all_data, df_all_train_index, seasonality_params=seasonalit
61176 # Log parameter, metrics, and model to MLflow
62177 mlflow .log_metric ("rmse" , df_p .loc [0 , "rmse" ])
63178
64- mlflow .pyfunc .log_model ("model" , python_model = FbProphetWrapper (model ))
179+ # Try to infer a model signature and attach an input example so MLflow
180+ # does not emit warnings about missing signatures and examples.
181+ try :
182+ if "ds" in df_test .columns :
183+ example_input = df_test [["ds" ]].head (10 )
184+ else :
185+ example_input = pd .DataFrame ({"periods" : [10 ]})
186+
187+ try :
188+ example_output = model .predict (example_input )
189+ except Exception :
190+ example_output = model .predict (df_test .head (10 ))
191+
192+ signature = infer_signature (example_input , example_output )
193+ mlflow .pyfunc .log_model (
194+ "model" ,
195+ python_model = FbProphetWrapper (model ),
196+ signature = signature ,
197+ input_example = example_input ,
198+ )
199+ except Exception :
200+ # If signature inference fails, log the model without signature to avoid aborting training
201+ mlflow .pyfunc .log_model ("model" , python_model = FbProphetWrapper (model ))
202+
65203 print (
66204 "Logged model with URI: runs:/{run_id}/model" .format (
67205 run_id = mlflow .active_run ().info .run_id
@@ -74,14 +212,12 @@ def train_predict(df_all_data, df_all_train_index, seasonality_params=seasonalit
74212
75213if __name__ == "__main__" :
76214 # Read in Data
77- df = pd .read_csv (' ../../Chapter01/forecasting-api/data/demand-forecasting-kernels-only/ train.csv' )
78- df .rename (columns = {'date' : 'ds' , 'sales' : 'y' }, inplace = True )
215+ df = pd .read_csv (" ../../Chapter01/forecasting/rossman_store_data/ train.csv" )
216+ df .rename (columns = {"Date" : "ds" , "Sales" : "y" }, inplace = True )
79217 # Filter out store and item 1
80- df_store1_item1 = df [(df ['store' ] == 1 ) & ( df [ 'item ' ] == 1 )].reset_index (drop = True )
218+ df_store1 = df [(df ['Store ' ] == 1 )].reset_index (drop = True )
81219
82- train_index = int (0.8 * df_store1_item1 .shape [0 ])
220+ train_index = int (0.8 * df_store1 .shape [0 ])
83221 predicted , df_train , df_test = train_predict (
84- df_all_data = df_store1_item1 ,
85- df_all_train_index = train_index ,
86- seasonality_params = seasonality
87- )
222+ df_all_data = df_store1 , df_all_train_index = train_index , seasonality_params = seasonality
223+ )
0 commit comments