Skip to content

Commit 47be7d9

Browse files
authored
Enable known_covariates in TimeSeriesCloudPredictor.predict_real_time (#181)
1 parent e8071ef commit 47be7d9

2 files changed

Lines changed: 15 additions & 5 deletions

File tree

src/autogluon/cloud/backend/timeseries_sagemaker_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def predict_real_time(
166166
static_features: Optional[pd.DataFrame]
167167
An optional data frame describing the metadata attributes of individual items in the item index.
168168
For more detail, please refer to `TimeSeriesDataFrame` documentation:
169-
https://auto.gluon.ai/stable/api/autogluon.predictor.html#timeseriesdataframe
169+
https://auto.gluon.ai/stable/api/autogluon.timeseries.TimeSeriesDataFrame.html
170170
target: str
171171
Name of column that contains the target values to forecast
172172
accept: str, default = application/x-parquet
@@ -225,7 +225,7 @@ def predict(
225225
static_features: Optional[Union[str, pd.DataFrame]]
226226
An optional data frame describing the metadata attributes of individual items in the item index.
227227
For more detail, please refer to `TimeSeriesDataFrame` documentation:
228-
https://auto.gluon.ai/stable/api/autogluon.predictor.html#timeseriesdataframe
228+
https://auto.gluon.ai/stable/api/autogluon.timeseries.TimeSeriesDataFrame.html
229229
target: str
230230
Name of column that contains the target values to forecast
231231
kwargs:

src/autogluon/cloud/predictor/timeseries_cloud_predictor.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def fit(
8080
static_features: Optional[pd.DataFrame]
8181
An optional data frame describing the metadata attributes of individual items in the item index.
8282
For more detail, please refer to `TimeSeriesDataFrame` documentation:
83-
https://auto.gluon.ai/stable/api/autogluon.predictor.html#timeseriesdataframe
83+
https://auto.gluon.ai/stable/api/autogluon.timeseries.TimeSeriesDataFrame.html
8484
framework_version: str, default = `latest`
8585
Training container version of autogluon.
8686
If `latest`, will use the latest available container version.
@@ -159,6 +159,7 @@ def predict_real_time(
159159
self,
160160
test_data: Union[str, pd.DataFrame],
161161
static_features: Optional[Union[str, pd.DataFrame]] = None,
162+
known_covariates: Optional[pd.DataFrame] = None,
162163
accept: str = "application/x-parquet",
163164
**kwargs,
164165
) -> pd.DataFrame:
@@ -175,7 +176,12 @@ def predict_real_time(
175176
static_features: Optional[pd.DataFrame]
176177
An optional data frame describing the metadata attributes of individual items in the item index.
177178
For more detail, please refer to `TimeSeriesDataFrame` documentation:
178-
https://auto.gluon.ai/stable/api/autogluon.predictor.html#timeseriesdataframe
179+
https://auto.gluon.ai/stable/api/autogluon.timeseries.TimeSeriesDataFrame.html
180+
known_covariates : Optional[pd.DataFrame]
181+
If ``known_covariates_names`` were specified when creating the predictor, it is necessary to provide the
182+
values of the known covariates for each time series during the forecast horizon.
183+
For more details, please refer to the `TimeSeriesPredictor.predictor` documentation:
184+
https://auto.gluon.ai/stable/api/autogluon.timeseries.TimeSeriesPredictor.predict.html
179185
accept: str, default = application/x-parquet
180186
Type of accept output content.
181187
Valid options are application/x-parquet, text/csv, application/json
@@ -198,6 +204,7 @@ def predict_real_time(
198204
target=self.target_column,
199205
static_features=static_features,
200206
accept=accept,
207+
inference_kwargs=dict(known_covariates=known_covariates, **kwargs),
201208
)
202209

203210
def predict_proba_real_time(self, **kwargs) -> pd.DataFrame:
@@ -224,6 +231,9 @@ def predict(
224231
This method would first create a AutoGluonSagemakerInferenceModel with the trained predictor,
225232
then create a transformer with it, and call transform in the end.
226233
234+
Note that batch prediction with `known_covariates` is currently not supported. Please use `predict_real_time`
235+
to predict with `known_covariates` instead.
236+
227237
Parameters
228238
----------
229239
test_data: str
@@ -232,7 +242,7 @@ def predict(
232242
static_features: Optional[Union[str, pd.DataFrame]]
233243
An optional data frame describing the metadata attributes of individual items in the item index.
234244
For more detail, please refer to `TimeSeriesDataFrame` documentation:
235-
https://auto.gluon.ai/stable/api/autogluon.predictor.html#timeseriesdataframe
245+
https://auto.gluon.ai/stable/api/autogluon.timeseries.TimeSeriesDataFrame.html
236246
target: str
237247
Name of column that contains the target values to forecast
238248
predictor_path: str

0 commit comments

Comments
 (0)