Skip to content

Commit b47148d

Browse files
committed
MLFlow and Github Actions
1 parent da929c5 commit b47148d

3 files changed

Lines changed: 167 additions & 35 deletions

File tree

.github/workflows/github-actions-basic.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ on: [push]
55
jobs:
66
build:
77

8-
runs-on: ubuntu-latest
8+
runs-on: windows-latest
99

1010
strategy:
1111
matrix:

Chapter02/mlewp-chapter02.yml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@ dependencies:
77
- libffi=3.4.2
88
- libsqlite=3.40.0
99
- libzlib=1.2.13
10-
- ncurses=6.3
1110
- openssl=3.1.0
1211
- pip=23.0.1
1312
- python=3.10.9
14-
- readline=8.1.2
1513
- setuptools=67.6.0
1614
- tk=8.6.12
1715
- tzdata=2022g
@@ -170,7 +168,5 @@ dependencies:
170168
- webcolors==1.12
171169
- webencodings==0.5.1
172170
- websocket-client==1.5.1
173-
- xattr==0.10.1
174171
- yarl==1.8.2
175172
- zipp==3.15.0
176-
prefix: /opt/homebrew/Caskroom/miniforge/base/envs/mlewp-chapter02
Lines changed: 166 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,153 @@
1-
#BASED ON EXAMPLE FROM MLFLOW DOCS
1+
# BASED ON EXAMPLE FROM MLFLOW DOCS
22
# https://github.com/mlflow/mlflow/blob/master/examples/prophet/train.py
33
import pandas as pd
4-
from fbprophet import Prophet
4+
from prophet import Prophet
55

6-
from fbprophet.diagnostics import cross_validation
7-
from fbprophet.diagnostics import performance_metrics
6+
from prophet.diagnostics import cross_validation
7+
from prophet.diagnostics import performance_metrics
88

99
import mlflow
1010
import mlflow.pyfunc
11+
from typing import Any
12+
from mlflow.models.signature import infer_signature
1113

1214
import logging
15+
1316
logging.basicConfig(level=logging.WARN)
1417
logger = logging.getLogger(__name__)
1518

19+
1620
class FbProphetWrapper(mlflow.pyfunc.PythonModel):
1721
def __init__(self, model):
22+
"""
23+
Initialize a FbProphetWrapper instance.
24+
25+
Parameters
26+
----------
27+
model : Prophet
28+
The Prophet model to be wrapped.
29+
30+
Returns
31+
-------
32+
None
33+
"""
1834
self.model = model
1935
super().__init__()
2036

2137
def load_context(self, context):
22-
from fbprophet import Prophet
23-
24-
return
25-
26-
def predict(self, context, model_input):
27-
future = self.model.make_future_dataframe(periods=model_input["periods"][0])
28-
return self.model.predict(future)
38+
"""
39+
Hook that MLflow calls when the saved model is loaded; this wrapper has nothing extra to restore, so it does nothing.
40+
41+
Parameters
42+
----------
43+
context : mlflow.pyfunc.ModelContext
44+
The model context to be loaded.
45+
46+
Returns
47+
-------
48+
None
49+
50+
Notes
51+
-----
52+
MLflow calls this method when the model is loaded and passes in the
53+
context. This implementation ignores the context and returns None, so
54+
nothing from the context is stored on the wrapper.
55+
"""
56+
# context is provided by MLflow at load time. If you need to
57+
# restore external artifacts, do it here. We keep this simple
58+
# because the Prophet model object is already serialized by MLflow
59+
# when logging the pyfunc wrapper.
60+
return None
61+
62+
def predict(self, context: Any, model_input: pd.DataFrame) -> pd.DataFrame:
63+
"""Make predictions using the wrapped Prophet model.
64+
65+
Parameters
66+
----------
67+
context : mlflow.pyfunc.ModelContext
68+
MLflow model context (provided at prediction time).
69+
model_input : pd.DataFrame
70+
Input dataframe (a dict with a "periods" entry is also accepted). Two main modes:
71+
- If it contains a column named "periods", the first value
72+
is taken as an integer number of periods to forecast and
73+
a future dataframe is created via Prophet's
74+
`make_future_dataframe`.
75+
- If it contains a column named "ds", it is treated as a
76+
dataframe of datestamps to run predict() on directly.
77+
78+
Returns
79+
-------
80+
pd.DataFrame
81+
The DataFrame returned by Prophet's `predict` method.
82+
"""
83+
# If user passed periods (e.g., during local scoring tests), create future df
84+
try:
85+
if isinstance(model_input, dict):
86+
# backward compatible: accept dict with 'periods'
87+
periods = int(model_input.get("periods", [0])[0])
88+
future = self.model.make_future_dataframe(periods=periods)
89+
return self.model.predict(future)
90+
91+
if "periods" in model_input.columns:
92+
periods = (
93+
int(model_input.loc[0, "periods"]) if not model_input.empty else 0
94+
)
95+
future = self.model.make_future_dataframe(periods=periods)
96+
return self.model.predict(future)
97+
98+
# If input contains datestamps (ds), predict for those rows
99+
if "ds" in model_input.columns:
100+
# Ensure ds is datetime
101+
mi = model_input.copy()
102+
mi["ds"] = pd.to_datetime(mi["ds"], errors="coerce")
103+
mi = mi.dropna(subset=["ds"]) # drop invalid dates
104+
return self.model.predict(mi)
105+
106+
# Fallback: try to read the first cell of the input as a numeric period count
107+
# and create a small future frame
108+
# (this is defensive; prefer explicit 'periods' or 'ds')
109+
periods = 0
110+
if not model_input.empty:
111+
# try to use the first cell (row 0, column 0) as integer periods
112+
try:
113+
periods = int(
114+
pd.to_numeric(model_input.iloc[0, 0], errors="coerce")
115+
)
116+
except Exception:
117+
periods = 0
118+
future = self.model.make_future_dataframe(periods=periods)
119+
return self.model.predict(future)
120+
except Exception as e:
121+
# Raise the exception so MLflow can capture the failure; include
122+
# a helpful message for debugging.
123+
raise RuntimeError(f"Error in FbProphetWrapper.predict: {e}") from e
124+
125+
126+
seasonality = {"yearly": True, "weekly": True, "daily": True}
29127

30128

31-
seasonality = {
32-
'yearly': True,
33-
'weekly': True,
34-
'daily': True
35-
}
36-
37129
def train_predict(df_all_data, df_all_train_index, seasonality_params=seasonality):
130+
"""
131+
Train a Prophet model on the given data and log the model and its metrics to MLflow.
132+
133+
Parameters
134+
----------
135+
df_all_data : pandas.DataFrame
136+
The dataframe containing all the data to be split into train and test sets.
137+
df_all_train_index : int
138+
Row position at which to split: rows before it form the train set, rows from it onward form the test set.
139+
seasonality_params : dict, optional
140+
A dictionary containing the yearly, weekly, and daily seasonality parameters for the Prophet model.
141+
142+
Returns
143+
-------
144+
tuple[pandas.DataFrame, pandas.DataFrame, pandas.DataFrame]
145+
A tuple containing the predicted values, the train dataframe, and the test dataframe.
146+
147+
Notes
148+
-----
149+
This function logs to MLflow inside a new run. The model is logged with the name "model", and a single metric is logged with the name "rmse", taken from the first row of the cross-validation performance table.
150+
"""
38151
# grab split data
39152
df_train = df_all_data.copy().iloc[0:df_all_train_index]
40153
df_test = df_all_data.copy().iloc[df_all_train_index:]
@@ -43,15 +156,17 @@ def train_predict(df_all_data, df_all_train_index, seasonality_params=seasonalit
43156
with mlflow.start_run():
44157
# create Prophet model
45158
model = Prophet(
46-
yearly_seasonality=seasonality_params['yearly'],
47-
weekly_seasonality=seasonality_params['weekly'],
48-
daily_seasonality=seasonality_params['daily']
159+
yearly_seasonality=seasonality_params["yearly"],
160+
weekly_seasonality=seasonality_params["weekly"],
161+
daily_seasonality=seasonality_params["daily"],
49162
)
50163
# train and predict
51164
model.fit(df_train)
52165

53166
# Evaluate Metrics
54-
df_cv = cross_validation(model, initial="730 days", period="180 days", horizon="365 days")
167+
df_cv = cross_validation(
168+
model, initial="540 days", period="180 days", horizon="90 days"
169+
)
55170
df_p = performance_metrics(df_cv)
56171

57172
# Print out metrics
@@ -61,7 +176,30 @@ def train_predict(df_all_data, df_all_train_index, seasonality_params=seasonalit
61176
# Log parameter, metrics, and model to MLflow
62177
mlflow.log_metric("rmse", df_p.loc[0, "rmse"])
63178

64-
mlflow.pyfunc.log_model("model", python_model=FbProphetWrapper(model))
179+
# Try to infer a model signature and attach an input example so MLflow
180+
# does not emit warnings about missing signatures and examples.
181+
try:
182+
if "ds" in df_test.columns:
183+
example_input = df_test[["ds"]].head(10)
184+
else:
185+
example_input = pd.DataFrame({"periods": [10]})
186+
187+
try:
188+
example_output = model.predict(example_input)
189+
except Exception:
190+
example_output = model.predict(df_test.head(10))
191+
192+
signature = infer_signature(example_input, example_output)
193+
mlflow.pyfunc.log_model(
194+
"model",
195+
python_model=FbProphetWrapper(model),
196+
signature=signature,
197+
input_example=example_input,
198+
)
199+
except Exception:
200+
# If signature inference fails, log the model without signature to avoid aborting training
201+
mlflow.pyfunc.log_model("model", python_model=FbProphetWrapper(model))
202+
65203
print(
66204
"Logged model with URI: runs:/{run_id}/model".format(
67205
run_id=mlflow.active_run().info.run_id
@@ -74,14 +212,12 @@ def train_predict(df_all_data, df_all_train_index, seasonality_params=seasonalit
74212

75213
if __name__ == "__main__":
76214
# Read in Data
77-
df = pd.read_csv('../../Chapter01/forecasting-api/data/demand-forecasting-kernels-only/train.csv')
78-
df.rename(columns={'date': 'ds', 'sales': 'y'}, inplace=True)
215+
df = pd.read_csv("../../Chapter01/forecasting/rossman_store_data/train.csv")
216+
df.rename(columns={"Date": "ds", "Sales": "y"}, inplace=True)
79217
# Keep only the rows for store 1
80-
df_store1_item1 = df[(df['store'] == 1) & (df['item'] == 1)].reset_index(drop=True)
218+
df_store1 = df[(df['Store'] == 1)].reset_index(drop=True)
81219

82-
train_index = int(0.8 * df_store1_item1.shape[0])
220+
train_index = int(0.8 * df_store1.shape[0])
83221
predicted, df_train, df_test = train_predict(
84-
df_all_data=df_store1_item1,
85-
df_all_train_index=train_index,
86-
seasonality_params=seasonality
87-
)
222+
df_all_data=df_store1, df_all_train_index=train_index, seasonality_params=seasonality
223+
)

0 commit comments

Comments
 (0)