diff --git a/CLAUDE.md b/CLAUDE.md index 3eb91b66..667be1b2 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -31,8 +31,9 @@ Common flags: Required environment (see `src/lib/config.py`): - `PSC_PLOT_DATA_DIR` — directory containing the data files (must be set; `set_data_dir.sh ` is a convenience script that exports it) - `PSC_PLOT_FFMPEG_BIN` — optional, falls back to `which ffmpeg`; needed for saving animations -- `PSC_PLOT_DASK_NUM_WORKERS` — optional, defaults to 1 -- `PSC_PLOT_DASK_CHUNK_SIZE` — optional, rows per dask partition for particle loads (default 1_000_000); reduce to bound peak memory on large files +- `PSC_PLOT_DASK_NUM_WORKERS` — optional, defaults to `os.cpu_count()` (or 1 if the CPU count cannot be determined) +- `PSC_PLOT_DASK_CHUNK_SIZE` — optional, rows per dask partition for particle loads (default 1_000_000); reduce to bound peak memory on large files +- `PSC_PLOT_DASK_SCHEDULER` — optional; if set to `"distributed"`, spins up a `dask.distributed.LocalCluster` with `n_workers=dask_num_workers, threads_per_worker=1, processes=True`; any other non-empty value (e.g. `"processes"` for dask's processes scheduler) is passed to `dask.config.set(scheduler=...)`. Unset or empty = dask default (threads). `PSC_PLOT_DATA_DIR` is read at module-import time (`CONFIG = PscPlotConfig._load()` in `src/lib/config.py`), so it must be set in the environment before any `lib.*` import. In tests, `tests/conftest.py` sets it before importing `lib`. 
diff --git a/pyproject.toml b/pyproject.toml index 214b2182..c1f2cfdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "xrft>=1.0", "pandas>=2.0", "dask>=2024.0", + "distributed>=2024.0", "h5py>=3.0", "tables>=3.10", "scipy>=1.10", diff --git a/src/lib/cli.py b/src/lib/cli.py index a821ed9b..926fa935 100644 --- a/src/lib/cli.py +++ b/src/lib/cli.py @@ -37,6 +37,13 @@ def _resolve_save_format(args: Args) -> SaveFormat | None: def main(): dask.config.set(num_workers=CONFIG.dask_num_workers) + if CONFIG.dask_scheduler == "distributed": + from dask.distributed import Client, LocalCluster + + cluster = LocalCluster(n_workers=CONFIG.dask_num_workers, threads_per_worker=1, processes=True) + Client(cluster) + elif CONFIG.dask_scheduler: + dask.config.set(scheduler=CONFIG.dask_scheduler) args = parsing.get_parsed_args() # resolve format BEFORE applying pipeline in order to fail early diff --git a/src/lib/config.py b/src/lib/config.py index 388a4686..820c664f 100644 --- a/src/lib/config.py +++ b/src/lib/config.py @@ -1,6 +1,5 @@ import os import shutil -import warnings from dataclasses import dataclass from pathlib import Path from typing import Callable, Self @@ -9,6 +8,7 @@ _FFMPEG_BIN_KEY = "PSC_PLOT_FFMPEG_BIN" _DASK_NUM_WORKERS_KEY = "PSC_PLOT_DASK_NUM_WORKERS" _DASK_CHUNK_SIZE_KEY = "PSC_PLOT_DASK_CHUNK_SIZE" +_DASK_SCHEDULER_KEY = "PSC_PLOT_DASK_SCHEDULER" def parse_optional[T](s: str | None, parser: Callable[[str], T]) -> T | None: @@ -23,6 +23,7 @@ class PscPlotConfig: ffmpeg_bin: Path | None dask_num_workers: int dask_chunk_size: int + dask_scheduler: str | None @classmethod def _load(cls) -> Self: @@ -35,15 +36,15 @@ def _load(cls) -> Self: dask_num_workers = parse_optional(os.environ.get(_DASK_NUM_WORKERS_KEY), int) if not dask_num_workers: - dask_num_workers = 1 - message = f"Number of dask workers not specified; defaulting to {dask_num_workers}. Set {_DASK_NUM_WORKERS_KEY} to specify." 
- warnings.warn(message) + dask_num_workers = os.cpu_count() or 1 dask_chunk_size = parse_optional(os.environ.get(_DASK_CHUNK_SIZE_KEY), int) if not dask_chunk_size: dask_chunk_size = 1_000_000 - return cls(data_dir, ffmpeg_bin, dask_num_workers, dask_chunk_size) + dask_scheduler = os.environ.get(_DASK_SCHEDULER_KEY) or None + + return cls(data_dir, ffmpeg_bin, dask_num_workers, dask_chunk_size, dask_scheduler) CONFIG = PscPlotConfig._load() diff --git a/src/lib/data/loaders/field_bp.py b/src/lib/data/loaders/field_bp.py index 3b9765d6..2f1a3836 100644 --- a/src/lib/data/loaders/field_bp.py +++ b/src/lib/data/loaders/field_bp.py @@ -18,6 +18,10 @@ def _get_path(prefix: str, step: int) -> Path: return CONFIG.data_dir / f"{prefix}.{step:09}.bp" +def _decode_psc(ds): + return pscpy.decode_psc(ds, ["e", "i"]) + + @loader class FieldLoaderBp(Loader): @classmethod @@ -34,7 +38,8 @@ def get_data(self) -> Field: paths=[_get_path(self.prefix, step) for step in self.steps], combine="nested", concat_dim="t", - preprocess=lambda ds: pscpy.decode_psc(ds, ["e", "i"]), + preprocess=_decode_psc, + parallel=True, ) if self.active_key is not None: derive_field_variable(ds, self.active_key, self.prefix) diff --git a/src/lib/derived_particle_variables/derived_particle_variable.py b/src/lib/derived_particle_variables/derived_particle_variable.py index f400ef74..060b4550 100644 --- a/src/lib/derived_particle_variables/derived_particle_variable.py +++ b/src/lib/derived_particle_variables/derived_particle_variable.py @@ -3,6 +3,7 @@ import pandas as pd +from lib import var_info_registry from lib.data.data_with_attrs import List __all__ = ["derived_particle_variable", "derive_particle_variable", "DERIVED_PARTICLE_VARIABLES"] @@ -25,7 +26,10 @@ def __init__( def assign_to(self, data: List) -> List: df = data.data - return data.assign_data(df.assign(**{self.name: self.derive(*(df[base_var_name] for base_var_name in self.base_var_names))})) + + info = var_info_registry.lookup("prt", 
self.name) + new_var_infos = {**data.metadata.var_infos, self.name: info} + return data.assign_data(df.assign(**{self.name: self.derive(*(df[base_var_name] for base_var_name in self.base_var_names))})).assign_metadata(var_infos=new_var_infos) def __repr__(self) -> str: return f"{self.__class__.__name__}(({', '.join(self.base_var_names)}) -> {self.name}: {self.derive!r})" diff --git a/tests/baseline/test_hamscan.png b/tests/baseline/test_hamscan.png new file mode 100644 index 00000000..4a2607ee Binary files /dev/null and b/tests/baseline/test_hamscan.png differ diff --git a/tests/test_plots.py b/tests/test_plots.py index 3881d74c..bd8e4782 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -151,6 +151,12 @@ def test_static_scatter_bp(): return make_plot("prt.e -i t=-1 -v y z time= --grid y=0.0625 z=0.0625".split(), data_dir="test-3d") +@pytest.mark.mpl_image_compare(**MPL_KWARGS) +def test_hamscan(): + """Archetypal scan for hammerhead distributions ("hams").""" + return make_plot("prt.e -i t=-1 --derive pzx --bin y py=20 pzx=20 t= -v py pzx time=y --compute".split(), data_dir="test-2d") + + # --- Particle moments ---