Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
921d958
fix: enhance job ID return structure in TrainPredictPipeline
BelhsanHmida Mar 9, 2026
b7068fa
fix: update forecasting job return structure in SensorAPI
BelhsanHmida Mar 9, 2026
505c284
fix: update job fetching logic in test_train_predict_pipeline
BelhsanHmida Mar 9, 2026
10b370a
chore: remove commented-out breakpoint in test_forecasting.py
BelhsanHmida Mar 9, 2026
6342867
fix: as_job is no longer in parameters
BelhsanHmida Mar 9, 2026
cc1bf1b
fix: update job count retrieval in add_forecast function
BelhsanHmida Mar 9, 2026
a7ff4a9
fix: add connection queue to fetch job
BelhsanHmida Mar 9, 2026
745e513
style: black
Flix6x Mar 9, 2026
47f160d
docs: changelog entry
Flix6x Mar 9, 2026
3af2c62
feat: check if wrap-up job actually finished rather than failed
Flix6x Mar 9, 2026
474b860
feat: add test case for 2 cycles, yielding 2 jobs and a wrap-up job
Flix6x Mar 9, 2026
7f824a7
dev: comment out failing assert, which needs to be investigated and up…
Flix6x Mar 9, 2026
29705a2
refactor: move checking the status of the wrap-up job to where it mat…
Flix6x Mar 9, 2026
f26c41b
fix: use job ID itself in case the returned job is the one existing c…
Flix6x Mar 9, 2026
f326efc
fix: add db.commit before forecasting jobs are created
BelhsanHmida Mar 10, 2026
67862ed
dev: uncomment test assertion statement
BelhsanHmida Mar 10, 2026
64465e9
Test(feat): search all beliefs forecasts saved into the sensor by the…
BelhsanHmida Mar 10, 2026
4ac1846
test(feat): add n_cycles variable to use to account for length of for…
BelhsanHmida Mar 10, 2026
bdc5e28
style: run pre-commit
BelhsanHmida Mar 10, 2026
5340350
fix: improve assertion message in test_train_predict_pipeline for cla…
BelhsanHmida Mar 10, 2026
8271e28
Merge branch 'main' into fix/small-forecasting-pipeline-fixes
BelhsanHmida Mar 11, 2026
60a85b3
Merge branch 'main' into fix/small-forecasting-pipeline-fixes
BelhsanHmida Mar 11, 2026
4ebaa97
fix: first create all jobs, then queue all jobs, giving the db.sessio…
Flix6x Mar 17, 2026
888b980
feat: enqueue job only after the transactional request
Flix6x Mar 17, 2026
56f825a
Revert "feat: enqueue job only after the transactional request"
Flix6x Mar 17, 2026
e41c1a5
docs: resolve silent merge conflict in changelog
Flix6x Mar 17, 2026
9085b9e
Merge remote-tracking branch 'origin/main' into fix/small-forecasting…
Flix6x Mar 17, 2026
064f9cb
docs: delete duplicate changelog entry
Flix6x Mar 16, 2026
7ff82d1
docs: add release date for v0.31.2
Flix6x Mar 17, 2026
657b0c0
Merge remote-tracking branch 'origin/main' into fix/small-forecasting…
Flix6x Mar 17, 2026
ff78988
docs: advance a different bugfix to v0.31.2
Flix6x Mar 17, 2026
4ed641a
fix: self.data_source found itself in a different session somehow, so…
Flix6x Mar 17, 2026
d276c8a
Revert "fix: first create all jobs, then queue all jobs, giving the d…
Flix6x Mar 17, 2026
ddd1cf2
fix: reload forecasting pipeline orm state in worker session
BelhsanHmida Mar 18, 2026
91f0e88
fix: serialize train-predict cycle jobs for workers
BelhsanHmida Mar 18, 2026
243c011
revert: commit not related to this pr
BelhsanHmida Mar 18, 2026
38dd206
style: run pre-commit
BelhsanHmida Mar 18, 2026
1cf183b
Merge branch 'main' into fix/detached-forecasting-sensors
BelhsanHmida Mar 18, 2026
341751a
fix: fix test merge commit
BelhsanHmida Mar 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions flexmeasures/data/models/forecasting/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import pandas as pd
from darts import TimeSeries
from darts.dataprocessing.transformers import MissingValuesFiller
from flexmeasures.data.models.time_series import Sensor
from timely_beliefs import utils as tb_utils

from flexmeasures.data import db
from flexmeasures.data.models.time_series import Sensor
from flexmeasures.data.models.forecasting.exceptions import NotEnoughDataException


Expand Down Expand Up @@ -69,8 +70,10 @@ def __init__(
predict_end: datetime | None = None,
missing_threshold: float = 1.0,
) -> None:
self.future = future_regressors
self.past = past_regressors
self.future = [
self._get_attached_sensor(sensor) for sensor in future_regressors
]
self.past = [self._get_attached_sensor(sensor) for sensor in past_regressors]
self.n_steps_to_predict = n_steps_to_predict
self.max_forecast_horizon = max_forecast_horizon
# rounds up so we get the number of viewpoints, each `forecast_frequency` apart
Expand All @@ -82,15 +85,15 @@ def __init__(
self.save_belief_time = (
save_belief_time # non floored belief time to save forecasts with
)
self.target_sensor = target_sensor
self.target = f"{target_sensor.name} (ID: {target_sensor.id})_target"
self.target_sensor = self._get_attached_sensor(target_sensor)
self.target = f"{self.target_sensor.name} (ID: {self.target_sensor.id})_target"
self.future_regressors = [
f"{sensor.name} (ID: {sensor.id})_FR-{idx}"
for idx, sensor in enumerate(future_regressors)
for idx, sensor in enumerate(self.future)
]
self.past_regressors = [
f"{sensor.name} (ID: {sensor.id})_PR-{idx}"
for idx, sensor in enumerate(past_regressors)
for idx, sensor in enumerate(self.past)
]
self.predict_start = predict_start if predict_start else None
self.predict_end = predict_end if predict_end else None
Expand All @@ -102,6 +105,15 @@ def __init__(
self.forecast_frequency = forecast_frequency
self.missing_threshold = missing_threshold

@staticmethod
def _get_attached_sensor(sensor: Sensor | int) -> Sensor:
    """Return *sensor* bound to the current db session.

    Accepts either a Sensor instance (possibly loaded in another session)
    or a plain sensor id, and reloads it through the active session so no
    cross-session ORM state leaks into this pipeline.

    :raises ValueError: if no sensor with the resolved id exists.
    """
    if isinstance(sensor, Sensor):
        sensor_id = sensor.id
    else:
        sensor_id = sensor
    reloaded = db.session.get(Sensor, sensor_id)
    if reloaded is None:
        raise ValueError(f"Could not load sensor with id {sensor_id}.")
    return reloaded

def load_data_all_beliefs(self) -> pd.DataFrame:
"""
This function fetches data for each sensor.
Expand Down
20 changes: 18 additions & 2 deletions flexmeasures/data/models/forecasting/pipelines/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from flexmeasures import Sensor, Source
from flexmeasures.data import db
from flexmeasures.data.models.data_sources import DataSource
from flexmeasures.data.models.forecasting.utils import data_to_bdf
from flexmeasures.data.models.forecasting.pipelines.base import BasePipeline
from flexmeasures.data.utils import save_to_db
Expand Down Expand Up @@ -83,7 +84,7 @@ def __init__(
self.quantiles = tuple(quantiles) if quantiles else None
self.forecast_horizon = np.arange(1, max_forecast_horizon + 1)
self.forecast_frequency = forecast_frequency
self.sensor_to_save = sensor_to_save
self.sensor_to_save = self._get_attached_sensor(sensor_to_save)
self.predict_start = predict_start
self.predict_end = predict_end

Expand All @@ -92,7 +93,22 @@ def __init__(
self.total_forecast_hours = (
self.max_forecast_horizon * self.sensor_resolution.total_seconds() / 3600
)
self.data_source = data_source
self.data_source = self._get_attached_data_source(data_source)

@staticmethod
def _get_attached_data_source(
    data_source: Source | int | None,
) -> DataSource | None:
    """Reload the prediction source through the active session before saving beliefs.

    Accepts a DataSource instance (possibly detached), a plain id, or None;
    returns None unchanged when no source was given.

    :raises ValueError: if no data source with the resolved id exists.
    """
    if data_source is None:
        return None
    if isinstance(data_source, DataSource):
        source_id = data_source.id
    else:
        source_id = data_source
    reloaded = db.session.get(DataSource, source_id)
    if reloaded is None:
        raise ValueError(f"Could not load data source with id {source_id}.")
    return reloaded

def load_model(self):
"""
Expand Down
100 changes: 88 additions & 12 deletions flexmeasures/data/models/forecasting/pipelines/train_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,41 @@
from flask import current_app

from flexmeasures.data import db
from flexmeasures.data.models.data_sources import DataSource
from flexmeasures.data.models.forecasting import Forecaster
from flexmeasures.data.models.forecasting.pipelines.predict import PredictPipeline
from flexmeasures.data.models.forecasting.pipelines.train import TrainPipeline
from flexmeasures.data.models.time_series import Sensor
from flexmeasures.data.schemas.forecasting.pipeline import (
ForecasterParametersSchema,
TrainPredictPipelineConfigSchema,
)
from flexmeasures.utils.flexmeasures_inflection import p


def run_train_predict_cycle_job(
    config: dict,
    parameters: dict,
    data_source_id: int,
    delete_model: bool,
    **cycle_params,
):
    """Reconstruct pipeline state inside the worker to avoid pickling ORM objects.

    The enqueuing process serializes config/parameters and passes only the
    data source id; here we rebuild everything in the worker's own session.
    """
    # Rebuild the pipeline from its serialized config.
    worker_pipeline = TrainPredictPipeline(config=config, delete_model=delete_model)
    # Deserialize parameters through the pipeline's own schema so ORM-backed
    # fields (sensors etc.) are loaded in this worker's session.
    worker_pipeline._parameters = worker_pipeline._parameters_schema.load(parameters)
    # Re-fetch the data source by id in the worker's session.
    worker_pipeline._data_source = db.session.get(DataSource, data_source_id)
    return worker_pipeline.run_cycle(**cycle_params)


def run_train_predict_wrap_up_job(cycle_job_ids: list[str]):
    """Log the status of all cycle jobs after completion."""
    queue_connection = current_app.queues["forecasting"].connection
    for index, job_id in enumerate(cycle_job_ids):
        cycle_job = Job.fetch(job_id, connection=queue_connection)
        status = cycle_job.get_status()
        logging.info(f"forecasting job-{index}: {job_id} status: {status}")


class TrainPredictPipeline(Forecaster):

__version__ = "1"
Expand Down Expand Up @@ -63,11 +88,7 @@ def _reattach_if_needed(obj):

def run_wrap_up(self, cycle_job_ids: list[str]):
"""Log the status of all cycle jobs after completion."""
connection = current_app.queues["forecasting"].connection

for index, job_id in enumerate(cycle_job_ids):
status = Job.fetch(job_id, connection=connection).get_status()
logging.info(f"forecasting job-{index}: {job_id} status: {status}")
run_train_predict_wrap_up_job(cycle_job_ids)

def run_cycle(
self,
Expand All @@ -82,6 +103,7 @@ def run_cycle(
"""
Runs a single training and prediction cycle.
"""
self._reattach_worker_state()
logging.info(
f"Starting Train-Predict cycle from {train_start} to {predict_end}"
)
Expand Down Expand Up @@ -183,6 +205,46 @@ def run_cycle(
)
return total_runtime

@staticmethod
def _get_attached_sensor(sensor: Sensor | int) -> Sensor:
    """Reload a sensor (given as instance or id) through the active db session.

    :raises ValueError: if no sensor with the resolved id exists.
    """
    if isinstance(sensor, Sensor):
        sensor_id = sensor.id
    else:
        sensor_id = sensor
    reloaded = db.session.get(Sensor, sensor_id)
    if reloaded is None:
        raise ValueError(f"Could not load sensor with id {sensor_id}.")
    return reloaded

@staticmethod
def _get_attached_data_source(
    data_source: DataSource | int | None,
) -> DataSource | None:
    """Reload a data source (instance, id, or None) through the active db session.

    Returns None unchanged when no source was given.

    :raises ValueError: if no data source with the resolved id exists.
    """
    if data_source is None:
        return None
    if isinstance(data_source, DataSource):
        source_id = data_source.id
    else:
        source_id = data_source
    reloaded = db.session.get(DataSource, source_id)
    if reloaded is None:
        raise ValueError(f"Could not load data source with id {source_id}.")
    return reloaded

def _reattach_worker_state(self) -> None:
    """Reload ORM objects through the worker's active session.

    Replaces every sensor in config/parameters, and the data source, with a
    session-bound equivalent so the worker never touches detached instances.
    """
    # Regressor sensor lists in the config.
    for config_key in ("future_regressors", "past_regressors"):
        self._config[config_key] = [
            self._get_attached_sensor(s) for s in self._config[config_key]
        ]
    # Individual sensors in the parameters.
    for parameter_key in ("sensor", "sensor_to_save"):
        self._parameters[parameter_key] = self._get_attached_sensor(
            self._parameters[parameter_key]
        )
    self._data_source = self._get_attached_data_source(self.data_source)

def _compute_forecast(self, as_job: bool = False, **kwargs) -> list[dict[str, Any]]:
    """Compute forecasts by delegating to the full train-and-predict run.

    Thin wrapper: run() handles both in-process and queued (as_job) execution.
    """
    return self.run(as_job=as_job, **kwargs)
Expand Down Expand Up @@ -295,11 +357,11 @@ def run(
if as_job:
cycle_job_ids = []

# Ensure the data source is attached to the current session before
# committing. get_or_create_source() only flushes (does not commit), so
# without this merge the data source would not be found by the worker.
db.session.merge(self.data_source)
# Ensure the data source ID is available in the database when the job runs.
self._data_source = db.session.merge(self.data_source)
db.session.commit()
serialized_config = self._config_schema.dump(self._config)
serialized_parameters = self._parameters_schema.dump(self._parameters)

# job metadata for tracking
# Serialize start and end to ISO format strings
Expand All @@ -313,9 +375,16 @@ def run(
for cycle_params in cycles_job_params:

job = Job.create(
self.run_cycle,
run_train_predict_cycle_job,
# Some cycle job params override job kwargs
kwargs={**job_kwargs, **cycle_params},
kwargs={
**job_kwargs,
"config": serialized_config,
"parameters": serialized_parameters,
"data_source_id": self.data_source.id,
"delete_model": self.delete_model,
**cycle_params,
},
connection=current_app.queues[queue].connection,
ttl=int(
current_app.config.get(
Expand Down Expand Up @@ -343,7 +412,7 @@ def run(
)

wrap_up_job = Job.create(
self.run_wrap_up,
run_train_predict_wrap_up_job,
kwargs={"cycle_job_ids": cycle_job_ids}, # cycles jobs IDs to wait for
connection=current_app.queues[queue].connection,
depends_on=cycle_job_ids, # wrap-up job depends on all cycle jobs
Expand All @@ -359,6 +428,7 @@ def run(
if len(cycle_job_ids) > 1:
# Return the wrap-up job ID if multiple cycle jobs are queued
return {"job_id": wrap_up_job.id, "n_jobs": len(cycle_job_ids)}
return {"job_id": wrap_up_job.id, "n_jobs": len(cycle_job_ids)}
else:
# Return the single cycle job ID if only one job is queued
return {
Expand All @@ -367,5 +437,11 @@ def run(
),
"n_jobs": 1,
}
return {
"job_id": (
cycle_job_ids[0] if len(cycle_job_ids) == 1 else wrap_up_job.id
),
"n_jobs": 1,
}

return self.return_values
Loading