Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions flexmeasures/api/v3_0/sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1632,6 +1632,18 @@ def trigger_forecast(self, id: int, **params):
# Put the sensor to save in the parameters
parameters["sensor"] = params["sensor_to_save"].id

# Check read permissions for regressor sensors specified in the config.
# The schema has already validated that these sensor IDs exist.
config = parameters.get("config", {})
regressor_ids = set(
config.get("future-regressors", [])
+ config.get("past-regressors", [])
+ config.get("regressors", [])
)
for regressor_id in regressor_ids:
regressor = db.session.get(Sensor, regressor_id)
check_access(regressor, "read")

# Set forecaster model
model = parameters.pop("model", "TrainPredictPipeline")

Expand Down
80 changes: 80 additions & 0 deletions flexmeasures/api/v3_0/tests/test_forecasting_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import timedelta

from flask import current_app
import isodate
import pytest
Expand All @@ -8,6 +10,9 @@
from flexmeasures.api.tests.utils import get_auth_token
from flexmeasures.data.services.forecasting import handle_forecasting_exception
from flexmeasures.data.models.forecasting.pipelines import TrainPredictPipeline
from flexmeasures.data import db
from flexmeasures.data.models.generic_assets import GenericAsset, GenericAssetType
from flexmeasures.data.models.time_series import Sensor


@pytest.mark.parametrize("requesting_user", ["test_admin_user@seita.nl"], indirect=True)
Expand Down Expand Up @@ -126,3 +131,78 @@ def test_trigger_and_fetch_forecasts(

# API should return exactly these most-recent beliefs
assert api_forecasts == expected_values


@pytest.mark.parametrize(
"regressor_field",
["future-regressors", "past-regressors", "regressors"],
)
@pytest.mark.parametrize(
"requesting_user", ["test_supplier_user_4@seita.nl"], indirect=True
)
def test_trigger_forecast_with_unreadable_regressor_returns_403(
app,
setup_roles_users_fresh_db,
setup_accounts_fresh_db,
requesting_user,
regressor_field,
):
"""Triggering a forecast that uses a regressor the requesting user cannot read must return 403."""

supplier_account = setup_accounts_fresh_db["Supplier"]
prosumer_account = setup_accounts_fresh_db["Prosumer"]

asset_type = GenericAssetType(name="test-asset-type-regressor-perm")
db.session.add(asset_type)

# Target sensor: owned by Supplier account – requesting user has create-children here
supplier_asset = GenericAsset(
name=f"supplier-target-asset-{regressor_field}",
generic_asset_type=asset_type,
owner=supplier_account,
)
db.session.add(supplier_asset)
target_sensor = Sensor(
name=f"supplier-target-sensor-{regressor_field}",
unit="kW",
event_resolution=timedelta(hours=1),
generic_asset=supplier_asset,
)
db.session.add(target_sensor)

# Regressor sensor: owned by Prosumer account – requesting user has no read access here
prosumer_asset = GenericAsset(
name=f"prosumer-private-regressor-asset-{regressor_field}",
generic_asset_type=asset_type,
owner=prosumer_account,
)
db.session.add(prosumer_asset)
private_regressor = Sensor(
name=f"prosumer-private-regressor-sensor-{regressor_field}",
unit="kW",
event_resolution=timedelta(hours=1),
generic_asset=prosumer_asset,
)
db.session.add(private_regressor)
db.session.commit()

client = app.test_client()
token = get_auth_token(client, "test_supplier_user_4@seita.nl", "testtest")

payload = {
"start": "2025-01-05T00:00:00+00:00",
"end": "2025-01-05T02:00:00+00:00",
"max-forecast-horizon": "PT1H",
"forecast-frequency": "PT1H",
"config": {
"train-start": "2025-01-01T00:00:00+00:00",
"retrain-frequency": "PT1H",
regressor_field: [private_regressor.id],
},
}

trigger_url = url_for("SensorAPI:trigger_forecast", id=target_sensor.id)
trigger_res = client.post(
trigger_url, json=payload, headers={"Authorization": token}
)
assert trigger_res.status_code == 403
Loading