From 920a89e6bed2eee34b464f6363588c59de69449f Mon Sep 17 00:00:00 2001 From: Fabian Date: Tue, 7 Apr 2026 20:56:00 +0200 Subject: [PATCH] perf: load cutout into memory and use dask threads --- rules/collect.smk | 16 ++++++++++ scripts/_helpers.py | 42 ++++++++++++++++++++++--- scripts/build_daily_heat_demand.py | 6 ++-- scripts/build_hac_features.py | 9 ++---- scripts/build_line_rating.py | 12 +++---- scripts/build_renewable_profiles.py | 9 ++---- scripts/build_solar_thermal_profiles.py | 6 ++-- scripts/build_temperature_profiles.py | 7 ++--- 8 files changed, 69 insertions(+), 38 deletions(-) diff --git a/rules/collect.smk b/rules/collect.smk index 743407ee92..2f5b00271c 100644 --- a/rules/collect.smk +++ b/rules/collect.smk @@ -30,6 +30,22 @@ rule process_costs: ), +rule create_renewable_profiles: + input: + expand( + resources("profile_{clusters}_{tech}.nc"), + tech=[ + tech + for tech in config["electricity"]["renewable_carriers"] + if tech != "hydro" + ], + **config["scenario"], + run=config["run"]["name"], + ), + message: + "Collecting renewable profiles." 
+ + rule cluster_networks: message: "Collecting clustered network files" diff --git a/scripts/_helpers.py b/scripts/_helpers.py index ad314aab28..7d4c175ca7 100644 --- a/scripts/_helpers.py +++ b/scripts/_helpers.py @@ -1024,12 +1024,37 @@ def rename_techs(label: str) -> str: return label +def _get_netcdf_chunk_sizes(path: str) -> dict[str, int]: + """Read chunk sizes from a netCDF file to align dask chunks with on-disk layout.""" + import netCDF4 + + nc = netCDF4.Dataset(path) + chunks = {} + for v in nc.variables.values(): + chunking = v.chunking() + if not isinstance(chunking, list) or len(v.dimensions) < 3: + continue + for dim, size in zip(v.dimensions, chunking): + if dim not in chunks: + chunks[dim] = size + break + nc.close() + return chunks + + def load_cutout( - cutout_files: str | list[str], time: None | pd.DatetimeIndex = None + cutout_files: str | list[str], + time: None | pd.DatetimeIndex = None, + chunks: dict | None = None, ) -> atlite.Cutout: """ Load and optionally combine multiple cutout files. + Reads chunk sizes from the netCDF file on disk to align dask chunks with + the storage layout, loads data eagerly into memory, then re-chunks as dask + arrays so downstream computation can use the threaded scheduler without + HDF5 thread-safety issues. + Parameters ---------- cutout_files : str or list of str @@ -1037,22 +1062,31 @@ def load_cutout( If a list is provided, the cutouts will be concatenated along the time dimension. time : pd.DatetimeIndex, optional If provided, select only the specified times from the cutout. + chunks : dict, optional + Dask chunk sizes for the returned cutout. If None, reads chunk sizes + from the netCDF file. Returns ------- atlite.Cutout - Merged cutout with optional time selection applied. + Cutout with in-memory data re-chunked as dask arrays. 
""" + first_file = cutout_files if isinstance(cutout_files, str) else cutout_files[0] + if chunks is None: + chunks = _get_netcdf_chunk_sizes(first_file) or {"time": 100} + if isinstance(cutout_files, str): - cutout = atlite.Cutout(cutout_files) + cutout = atlite.Cutout(cutout_files, chunks=chunks) elif isinstance(cutout_files, list): - cutout_da = [atlite.Cutout(c).data for c in cutout_files] + cutout_da = [atlite.Cutout(c, chunks=chunks).data for c in cutout_files] combined_data = xr.concat(cutout_da, dim="time", data_vars="minimal") cutout = atlite.Cutout(NamedTemporaryFile().name, data=combined_data) if time is not None: cutout.data = cutout.data.sel(time=time) + cutout.data = cutout.data.load().chunk(chunks) + return cutout diff --git a/scripts/build_daily_heat_demand.py b/scripts/build_daily_heat_demand.py index 2ff6f73d4c..358c48857c 100644 --- a/scripts/build_daily_heat_demand.py +++ b/scripts/build_daily_heat_demand.py @@ -16,10 +16,10 @@ import logging +import dask import geopandas as gpd import numpy as np import xarray as xr -from dask.distributed import Client, LocalCluster from scripts._helpers import ( configure_logging, @@ -43,8 +43,7 @@ set_scenario_config(snakemake) nprocesses = int(snakemake.threads) - cluster = LocalCluster(n_workers=nprocesses, threads_per_worker=1) - client = Client(cluster, asynchronous=True) + dask.config.set(scheduler="threads", num_workers=nprocesses) cutout_name = snakemake.input.cutout @@ -71,7 +70,6 @@ heat_demand = cutout.heat_demand( matrix=M.T, index=clustered_regions.index, - dask_kwargs=dict(scheduler=client), show_progress=False, ).sel(time=daily) diff --git a/scripts/build_hac_features.py b/scripts/build_hac_features.py index 82c4d4e574..31968ededf 100644 --- a/scripts/build_hac_features.py +++ b/scripts/build_hac_features.py @@ -7,9 +7,9 @@ import logging +import dask import geopandas as gpd from atlite.aggregate import aggregate_matrix -from dask.distributed import Client from scripts._helpers import ( 
configure_logging, @@ -31,10 +31,7 @@ params = snakemake.params nprocesses = int(snakemake.threads) - if nprocesses > 1: - client = Client(n_workers=nprocesses, threads_per_worker=1) - else: - client = None + dask.config.set(scheduler="threads", num_workers=nprocesses) time = get_snapshots(params.snapshots, params.drop_leap_day) @@ -47,6 +44,6 @@ aggregate_matrix, matrix=I, index=regions.index ) - ds = ds.load(scheduler=client) + ds = ds.load() ds.to_netcdf(snakemake.output[0]) diff --git a/scripts/build_line_rating.py b/scripts/build_line_rating.py index 2d15f7113a..fb9516603d 100755 --- a/scripts/build_line_rating.py +++ b/scripts/build_line_rating.py @@ -29,11 +29,11 @@ import re import atlite +import dask import geopandas as gpd import numpy as np import pypsa import xarray as xr -from dask.distributed import Client from shapely.geometry import LineString as Line from shapely.geometry import Point @@ -73,7 +73,7 @@ def calculate_line_rating( n: pypsa.Network, cutout: atlite.Cutout, show_progress: bool = True, - dask_kwargs: dict = None, + dask_kwargs: dict | None = None, ) -> xr.DataArray: """ Calculates the maximal allowed power flow in each line for each time step @@ -144,16 +144,12 @@ def calculate_line_rating( nprocesses = int(snakemake.threads) show_progress = not snakemake.config["run"].get("disable_progressbar", True) show_progress = show_progress and snakemake.config["atlite"]["show_progress"] - if nprocesses > 1: - client = Client(n_workers=nprocesses, threads_per_worker=1) - else: - client = None - dask_kwargs = {"scheduler": client} + dask.config.set(scheduler="threads", num_workers=nprocesses) n = pypsa.Network(snakemake.input.base_network) time = get_snapshots(snakemake.params.snapshots, snakemake.params.drop_leap_day) cutout = load_cutout(snakemake.input.cutout, time=time) - da = calculate_line_rating(n, cutout, show_progress, dask_kwargs) + da = calculate_line_rating(n, cutout, show_progress) da.to_netcdf(snakemake.output[0]) diff --git 
a/scripts/build_renewable_profiles.py b/scripts/build_renewable_profiles.py index 1f740f5c3f..b3760ffe9d 100644 --- a/scripts/build_renewable_profiles.py +++ b/scripts/build_renewable_profiles.py @@ -92,12 +92,12 @@ import time from itertools import product +import dask import geopandas as gpd import numpy as np import pandas as pd import xarray as xr from atlite.gis import ExclusionContainer -from dask.distributed import Client from scripts._helpers import ( configure_logging, @@ -140,10 +140,7 @@ if correction_factor != 1.0: logger.info(f"correction_factor is set as {correction_factor}") - if nprocesses > 1: - client = Client(n_workers=nprocesses, threads_per_worker=1) - else: - client = None + dask.config.set(scheduler="threads", num_workers=nprocesses) sns = get_snapshots(snakemake.params.snapshots, snakemake.params.drop_leap_day) @@ -173,8 +170,6 @@ ) func = getattr(cutout, resource.pop("method")) - if client is not None: - resource["dask_kwargs"] = {"scheduler": client} logger.info( f"Calculate average capacity factor per grid cell for technology {technology}..." 
diff --git a/scripts/build_solar_thermal_profiles.py b/scripts/build_solar_thermal_profiles.py index 8e0b845ce3..0375203049 100644 --- a/scripts/build_solar_thermal_profiles.py +++ b/scripts/build_solar_thermal_profiles.py @@ -13,10 +13,10 @@ import logging +import dask import geopandas as gpd import numpy as np import xarray as xr -from dask.distributed import Client, LocalCluster from scripts._helpers import ( configure_logging, @@ -36,8 +36,7 @@ set_scenario_config(snakemake) nprocesses = int(snakemake.threads) - cluster = LocalCluster(n_workers=nprocesses, threads_per_worker=1) - client = Client(cluster, asynchronous=True) + dask.config.set(scheduler="threads", num_workers=nprocesses) config = snakemake.params.solar_thermal config.pop("cutout", None) @@ -65,7 +64,6 @@ **config, matrix=M_tilde.T, index=clustered_regions.index, - dask_kwargs=dict(scheduler=client), show_progress=False, ) diff --git a/scripts/build_temperature_profiles.py b/scripts/build_temperature_profiles.py index f642578e2e..3e81e2e680 100644 --- a/scripts/build_temperature_profiles.py +++ b/scripts/build_temperature_profiles.py @@ -15,10 +15,10 @@ import logging +import dask import geopandas as gpd import numpy as np import xarray as xr -from dask.distributed import Client, LocalCluster from scripts._helpers import ( configure_logging, @@ -41,8 +41,7 @@ set_scenario_config(snakemake) nprocesses = int(snakemake.threads) - cluster = LocalCluster(n_workers=nprocesses, threads_per_worker=1) - client = Client(cluster, asynchronous=True) + dask.config.set(scheduler="threads", num_workers=nprocesses) time = get_snapshots(snakemake.params.snapshots, snakemake.params.drop_leap_day) @@ -66,7 +65,6 @@ temp_air = cutout.temperature( matrix=M_tilde.T, index=clustered_regions.index, - dask_kwargs=dict(scheduler=client), show_progress=False, ) @@ -75,7 +73,6 @@ temp_soil = cutout.soil_temperature( matrix=M_tilde.T, index=clustered_regions.index, - dask_kwargs=dict(scheduler=client), show_progress=False, )