diff --git a/flexmeasures/data/models/forecasting/pipelines/base.py b/flexmeasures/data/models/forecasting/pipelines/base.py index 9975f49c4b..c2bf17add0 100644 --- a/flexmeasures/data/models/forecasting/pipelines/base.py +++ b/flexmeasures/data/models/forecasting/pipelines/base.py @@ -8,9 +8,10 @@ import pandas as pd from darts import TimeSeries from darts.dataprocessing.transformers import MissingValuesFiller -from flexmeasures.data.models.time_series import Sensor from timely_beliefs import utils as tb_utils +from flexmeasures.data import db +from flexmeasures.data.models.time_series import Sensor from flexmeasures.data.models.forecasting.exceptions import NotEnoughDataException @@ -69,8 +70,10 @@ def __init__( predict_end: datetime | None = None, missing_threshold: float = 1.0, ) -> None: - self.future = future_regressors - self.past = past_regressors + self.future = [ + self._get_attached_sensor(sensor) for sensor in future_regressors + ] + self.past = [self._get_attached_sensor(sensor) for sensor in past_regressors] self.n_steps_to_predict = n_steps_to_predict self.max_forecast_horizon = max_forecast_horizon # rounds up so we get the number of viewpoints, each `forecast_frequency` apart @@ -82,15 +85,15 @@ def __init__( self.save_belief_time = ( save_belief_time # non floored belief time to save forecasts with ) - self.target_sensor = target_sensor - self.target = f"{target_sensor.name} (ID: {target_sensor.id})_target" + self.target_sensor = self._get_attached_sensor(target_sensor) + self.target = f"{self.target_sensor.name} (ID: {self.target_sensor.id})_target" self.future_regressors = [ f"{sensor.name} (ID: {sensor.id})_FR-{idx}" - for idx, sensor in enumerate(future_regressors) + for idx, sensor in enumerate(self.future) ] self.past_regressors = [ f"{sensor.name} (ID: {sensor.id})_PR-{idx}" - for idx, sensor in enumerate(past_regressors) + for idx, sensor in enumerate(self.past) ] self.predict_start = predict_start if predict_start else None 
self.predict_end = predict_end if predict_end else None @@ -102,6 +105,15 @@ def __init__( self.forecast_frequency = forecast_frequency self.missing_threshold = missing_threshold + @staticmethod + def _get_attached_sensor(sensor: Sensor | int) -> Sensor: + """Reload sensors through the active session to avoid cross-session ORM state.""" + sensor_id = sensor.id if isinstance(sensor, Sensor) else sensor + attached_sensor = db.session.get(Sensor, sensor_id) + if attached_sensor is None: + raise ValueError(f"Could not load sensor with id {sensor_id}.") + return attached_sensor + def load_data_all_beliefs(self) -> pd.DataFrame: """ This function fetches data for each sensor. diff --git a/flexmeasures/data/models/forecasting/pipelines/predict.py b/flexmeasures/data/models/forecasting/pipelines/predict.py index 78fca20420..9e23b55237 100644 --- a/flexmeasures/data/models/forecasting/pipelines/predict.py +++ b/flexmeasures/data/models/forecasting/pipelines/predict.py @@ -13,6 +13,7 @@ from flexmeasures import Sensor, Source from flexmeasures.data import db +from flexmeasures.data.models.data_sources import DataSource from flexmeasures.data.models.forecasting.utils import data_to_bdf from flexmeasures.data.models.forecasting.pipelines.base import BasePipeline from flexmeasures.data.utils import save_to_db @@ -83,7 +84,7 @@ def __init__( self.quantiles = tuple(quantiles) if quantiles else None self.forecast_horizon = np.arange(1, max_forecast_horizon + 1) self.forecast_frequency = forecast_frequency - self.sensor_to_save = sensor_to_save + self.sensor_to_save = self._get_attached_sensor(sensor_to_save) self.predict_start = predict_start self.predict_end = predict_end @@ -92,7 +93,22 @@ def __init__( self.total_forecast_hours = ( self.max_forecast_horizon * self.sensor_resolution.total_seconds() / 3600 ) - self.data_source = data_source + self.data_source = self._get_attached_data_source(data_source) + + @staticmethod + def _get_attached_data_source( + data_source: Source | 
int | None, + ) -> DataSource | None: + """Reload the prediction source through the active session before saving beliefs.""" + if data_source is None: + return None + source_id = ( + data_source.id if isinstance(data_source, DataSource) else data_source + ) + attached_source = db.session.get(DataSource, source_id) + if attached_source is None: + raise ValueError(f"Could not load data source with id {source_id}.") + return attached_source def load_model(self): """ diff --git a/flexmeasures/data/models/forecasting/pipelines/train_predict.py b/flexmeasures/data/models/forecasting/pipelines/train_predict.py index b21001fa0a..3d19780643 100644 --- a/flexmeasures/data/models/forecasting/pipelines/train_predict.py +++ b/flexmeasures/data/models/forecasting/pipelines/train_predict.py @@ -13,9 +13,11 @@ from flask import current_app from flexmeasures.data import db +from flexmeasures.data.models.data_sources import DataSource from flexmeasures.data.models.forecasting import Forecaster from flexmeasures.data.models.forecasting.pipelines.predict import PredictPipeline from flexmeasures.data.models.forecasting.pipelines.train import TrainPipeline +from flexmeasures.data.models.time_series import Sensor from flexmeasures.data.schemas.forecasting.pipeline import ( ForecasterParametersSchema, TrainPredictPipelineConfigSchema, @@ -23,6 +25,29 @@ from flexmeasures.utils.flexmeasures_inflection import p +def run_train_predict_cycle_job( + config: dict, + parameters: dict, + data_source_id: int, + delete_model: bool, + **cycle_params, +): + """Reconstruct pipeline state inside the worker to avoid pickling ORM objects.""" + pipeline = TrainPredictPipeline(config=config, delete_model=delete_model) + pipeline._parameters = pipeline._parameters_schema.load(parameters) + pipeline._data_source = db.session.get(DataSource, data_source_id) + return pipeline.run_cycle(**cycle_params) + + +def run_train_predict_wrap_up_job(cycle_job_ids: list[str]): + """Log the status of all cycle jobs after 
completion.""" + connection = current_app.queues["forecasting"].connection + + for index, job_id in enumerate(cycle_job_ids): + status = Job.fetch(job_id, connection=connection).get_status() + logging.info(f"forecasting job-{index}: {job_id} status: {status}") + + class TrainPredictPipeline(Forecaster): __version__ = "1" @@ -63,11 +88,7 @@ def _reattach_if_needed(obj): def run_wrap_up(self, cycle_job_ids: list[str]): """Log the status of all cycle jobs after completion.""" - connection = current_app.queues["forecasting"].connection - - for index, job_id in enumerate(cycle_job_ids): - status = Job.fetch(job_id, connection=connection).get_status() - logging.info(f"forecasting job-{index}: {job_id} status: {status}") + run_train_predict_wrap_up_job(cycle_job_ids) def run_cycle( self, @@ -82,6 +103,7 @@ def run_cycle( """ Runs a single training and prediction cycle. """ + self._reattach_worker_state() logging.info( f"Starting Train-Predict cycle from {train_start} to {predict_end}" ) @@ -183,6 +205,46 @@ def run_cycle( ) return total_runtime + @staticmethod + def _get_attached_sensor(sensor: Sensor | int) -> Sensor: + sensor_id = sensor.id if isinstance(sensor, Sensor) else sensor + attached_sensor = db.session.get(Sensor, sensor_id) + if attached_sensor is None: + raise ValueError(f"Could not load sensor with id {sensor_id}.") + return attached_sensor + + @staticmethod + def _get_attached_data_source( + data_source: DataSource | int | None, + ) -> DataSource | None: + if data_source is None: + return None + source_id = ( + data_source.id if isinstance(data_source, DataSource) else data_source + ) + attached_source = db.session.get(DataSource, source_id) + if attached_source is None: + raise ValueError(f"Could not load data source with id {source_id}.") + return attached_source + + def _reattach_worker_state(self) -> None: + """Reload ORM objects through the worker's active session.""" + self._config["future_regressors"] = [ + self._get_attached_sensor(sensor) + for 
sensor in self._config["future_regressors"] + ] + self._config["past_regressors"] = [ + self._get_attached_sensor(sensor) + for sensor in self._config["past_regressors"] + ] + self._parameters["sensor"] = self._get_attached_sensor( + self._parameters["sensor"] + ) + self._parameters["sensor_to_save"] = self._get_attached_sensor( + self._parameters["sensor_to_save"] + ) + self._data_source = self._get_attached_data_source(self.data_source) + def _compute_forecast(self, as_job: bool = False, **kwargs) -> list[dict[str, Any]]: # Run the train-and-predict pipeline return self.run(as_job=as_job, **kwargs) @@ -295,11 +357,11 @@ def run( if as_job: cycle_job_ids = [] - # Ensure the data source is attached to the current session before - # committing. get_or_create_source() only flushes (does not commit), so - # without this merge the data source would not be found by the worker. - db.session.merge(self.data_source) + # Ensure the data source ID is available in the database when the job runs. + self._data_source = db.session.merge(self.data_source) db.session.commit() + serialized_config = self._config_schema.dump(self._config) + serialized_parameters = self._parameters_schema.dump(self._parameters) # job metadata for tracking # Serialize start and end to ISO format strings @@ -313,9 +375,16 @@ def run( for cycle_params in cycles_job_params: job = Job.create( - self.run_cycle, + run_train_predict_cycle_job, # Some cycle job params override job kwargs - kwargs={**job_kwargs, **cycle_params}, + kwargs={ + **job_kwargs, + "config": serialized_config, + "parameters": serialized_parameters, + "data_source_id": self.data_source.id, + "delete_model": self.delete_model, + **cycle_params, + }, connection=current_app.queues[queue].connection, ttl=int( current_app.config.get( @@ -343,7 +412,7 @@ def run( ) wrap_up_job = Job.create( - self.run_wrap_up, + run_train_predict_wrap_up_job, kwargs={"cycle_job_ids": cycle_job_ids}, # cycles jobs IDs to wait for 
             connection=current_app.queues[queue].connection,
             depends_on=cycle_job_ids,  # wrap-up job depends on all cycle jobs