diff --git a/flexmeasures/api/v3_0/sensors.py b/flexmeasures/api/v3_0/sensors.py index 62a427321..c5d2e31b1 100644 --- a/flexmeasures/api/v3_0/sensors.py +++ b/flexmeasures/api/v3_0/sensors.py @@ -1632,6 +1632,18 @@ def trigger_forecast(self, id: int, **params): # Put the sensor to save in the parameters parameters["sensor"] = params["sensor_to_save"].id + # Check read permissions for regressor sensors specified in the config. + # The schema has already validated that these sensor IDs exist. + config = parameters.get("config", {}) + regressor_ids = set( + config.get("future-regressors", []) + + config.get("past-regressors", []) + + config.get("regressors", []) + ) + for regressor_id in regressor_ids: + regressor = db.session.get(Sensor, regressor_id) + check_access(regressor, "read") + # Set forecaster model model = parameters.pop("model", "TrainPredictPipeline") diff --git a/flexmeasures/api/v3_0/tests/test_forecasting_api.py b/flexmeasures/api/v3_0/tests/test_forecasting_api.py index df7e3fef8..fddc8288c 100644 --- a/flexmeasures/api/v3_0/tests/test_forecasting_api.py +++ b/flexmeasures/api/v3_0/tests/test_forecasting_api.py @@ -1,3 +1,5 @@ +from datetime import timedelta + from flask import current_app import isodate import pytest @@ -8,6 +10,9 @@ from flexmeasures.api.tests.utils import get_auth_token from flexmeasures.data.services.forecasting import handle_forecasting_exception from flexmeasures.data.models.forecasting.pipelines import TrainPredictPipeline +from flexmeasures.data import db +from flexmeasures.data.models.generic_assets import GenericAsset, GenericAssetType +from flexmeasures.data.models.time_series import Sensor @pytest.mark.parametrize("requesting_user", ["test_admin_user@seita.nl"], indirect=True) @@ -126,3 +131,78 @@ def test_trigger_and_fetch_forecasts( # API should return exactly these most-recent beliefs assert api_forecasts == expected_values + + +@pytest.mark.parametrize( + "regressor_field", + ["future-regressors", "past-regressors", "regressors"], +) +@pytest.mark.parametrize( + "requesting_user", ["test_supplier_user_4@seita.nl"], indirect=True +) +def test_trigger_forecast_with_unreadable_regressor_returns_403( + app, + setup_roles_users_fresh_db, + setup_accounts_fresh_db, + requesting_user, + regressor_field, +): + """Triggering a forecast that uses a regressor the requesting user cannot read must return 403.""" + + supplier_account = setup_accounts_fresh_db["Supplier"] + prosumer_account = setup_accounts_fresh_db["Prosumer"] + + asset_type = GenericAssetType(name="test-asset-type-regressor-perm") + db.session.add(asset_type) + + # Target sensor: owned by Supplier account – requesting user has create-children here + supplier_asset = GenericAsset( + name=f"supplier-target-asset-{regressor_field}", + generic_asset_type=asset_type, + owner=supplier_account, + ) + db.session.add(supplier_asset) + target_sensor = Sensor( + name=f"supplier-target-sensor-{regressor_field}", + unit="kW", + event_resolution=timedelta(hours=1), + generic_asset=supplier_asset, + ) + db.session.add(target_sensor) + + # Regressor sensor: owned by Prosumer account – requesting user has no read access here + prosumer_asset = GenericAsset( + name=f"prosumer-private-regressor-asset-{regressor_field}", + generic_asset_type=asset_type, + owner=prosumer_account, + ) + db.session.add(prosumer_asset) + private_regressor = Sensor( + name=f"prosumer-private-regressor-sensor-{regressor_field}", + unit="kW", + event_resolution=timedelta(hours=1), + generic_asset=prosumer_asset, + ) + db.session.add(private_regressor) + db.session.commit() + + client = app.test_client() + token = get_auth_token(client, "test_supplier_user_4@seita.nl", "testtest") + + payload = { + "start": "2025-01-05T00:00:00+00:00", + "end": "2025-01-05T02:00:00+00:00", + "max-forecast-horizon": "PT1H", + "forecast-frequency": "PT1H", + "config": { + "train-start": "2025-01-01T00:00:00+00:00", + "retrain-frequency": "PT1H", + regressor_field: [private_regressor.id], + }, + } + + trigger_url = url_for("SensorAPI:trigger_forecast", id=target_sensor.id) + trigger_res = client.post( + trigger_url, json=payload, headers={"Authorization": token} + ) + assert trigger_res.status_code == 403