diff --git a/CLAUDE.md b/CLAUDE.md
index 3eb91b66..667be1b2 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -31,8 +31,9 @@ Common flags:
Required environment (see `src/lib/config.py`):
- `PSC_PLOT_DATA_DIR` — directory containing the data files (must be set; `set_data_dir.sh
` is a convenience script that exports it)
- `PSC_PLOT_FFMPEG_BIN` — optional, falls back to `which ffmpeg`; needed for saving animations
-- `PSC_PLOT_DASK_NUM_WORKERS` — optional, defaults to 1
-- `PSC_PLOT_DASK_CHUNK_SIZE` — optional, rows per dask partition for particle loads (default 1_000_000); reduce to bound peak memory on large files
+- `PSC_PLOT_DASK_NUM_WORKERS` — optional, defaults to `os.cpu_count()`
+- `PSC_PLOT_DASK_CHUNK_SIZE` — optional, rows per dask partition for particle loads (default 1_000_000); reduce to bound peak memory on large files
+- `PSC_PLOT_DASK_SCHEDULER` — optional; if set to `"processes"`, uses dask's processes scheduler; if `"distributed"`, spins up a `dask.distributed.LocalCluster` with `n_workers=dask_num_workers, threads_per_worker=1, processes=True`. Unset = dask default (threads).
`PSC_PLOT_DATA_DIR` is read at module-import time (`CONFIG = PscPlotConfig._load()` in `src/lib/config.py`), so it must be set in the environment before any `lib.*` import. In tests, `tests/conftest.py` sets it before importing `lib`.
diff --git a/pyproject.toml b/pyproject.toml
index 214b2182..c1f2cfdf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,7 @@ dependencies = [
"xrft>=1.0",
"pandas>=2.0",
"dask>=2024.0",
+ "distributed>=2024.0",
"h5py>=3.0",
"tables>=3.10",
"scipy>=1.10",
diff --git a/src/lib/cli.py b/src/lib/cli.py
index a821ed9b..926fa935 100644
--- a/src/lib/cli.py
+++ b/src/lib/cli.py
@@ -37,6 +37,13 @@ def _resolve_save_format(args: Args) -> SaveFormat | None:
def main():
dask.config.set(num_workers=CONFIG.dask_num_workers)
+ if CONFIG.dask_scheduler == "distributed":
+ from dask.distributed import Client, LocalCluster
+
+ cluster = LocalCluster(n_workers=CONFIG.dask_num_workers, threads_per_worker=1, processes=True)
+ Client(cluster)
+ elif CONFIG.dask_scheduler:
+ dask.config.set(scheduler=CONFIG.dask_scheduler)
args = parsing.get_parsed_args()
# resolve format BEFORE applying pipeline in order to fail early
diff --git a/src/lib/config.py b/src/lib/config.py
index 388a4686..820c664f 100644
--- a/src/lib/config.py
+++ b/src/lib/config.py
@@ -1,6 +1,5 @@
import os
import shutil
-import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Self
@@ -9,6 +8,7 @@
_FFMPEG_BIN_KEY = "PSC_PLOT_FFMPEG_BIN"
_DASK_NUM_WORKERS_KEY = "PSC_PLOT_DASK_NUM_WORKERS"
_DASK_CHUNK_SIZE_KEY = "PSC_PLOT_DASK_CHUNK_SIZE"
+_DASK_SCHEDULER_KEY = "PSC_PLOT_DASK_SCHEDULER"
def parse_optional[T](s: str | None, parser: Callable[[str], T]) -> T | None:
@@ -23,6 +23,7 @@ class PscPlotConfig:
ffmpeg_bin: Path | None
dask_num_workers: int
dask_chunk_size: int
+ dask_scheduler: str | None
@classmethod
def _load(cls) -> Self:
@@ -35,15 +36,15 @@ def _load(cls) -> Self:
dask_num_workers = parse_optional(os.environ.get(_DASK_NUM_WORKERS_KEY), int)
if not dask_num_workers:
- dask_num_workers = 1
- message = f"Number of dask workers not specified; defaulting to {dask_num_workers}. Set {_DASK_NUM_WORKERS_KEY} to specify."
- warnings.warn(message)
+ dask_num_workers = os.cpu_count() or 1
dask_chunk_size = parse_optional(os.environ.get(_DASK_CHUNK_SIZE_KEY), int)
if not dask_chunk_size:
dask_chunk_size = 1_000_000
- return cls(data_dir, ffmpeg_bin, dask_num_workers, dask_chunk_size)
+ dask_scheduler = os.environ.get(_DASK_SCHEDULER_KEY) or None
+
+ return cls(data_dir, ffmpeg_bin, dask_num_workers, dask_chunk_size, dask_scheduler)
CONFIG = PscPlotConfig._load()
diff --git a/src/lib/data/loaders/field_bp.py b/src/lib/data/loaders/field_bp.py
index 3b9765d6..2f1a3836 100644
--- a/src/lib/data/loaders/field_bp.py
+++ b/src/lib/data/loaders/field_bp.py
@@ -18,6 +18,10 @@ def _get_path(prefix: str, step: int) -> Path:
return CONFIG.data_dir / f"{prefix}.{step:09}.bp"
+def _decode_psc(ds):
+ return pscpy.decode_psc(ds, ["e", "i"])
+
+
@loader
class FieldLoaderBp(Loader):
@classmethod
@@ -34,7 +38,8 @@ def get_data(self) -> Field:
paths=[_get_path(self.prefix, step) for step in self.steps],
combine="nested",
concat_dim="t",
- preprocess=lambda ds: pscpy.decode_psc(ds, ["e", "i"]),
+ preprocess=_decode_psc,
+ parallel=True,
)
if self.active_key is not None:
derive_field_variable(ds, self.active_key, self.prefix)
diff --git a/src/lib/derived_particle_variables/derived_particle_variable.py b/src/lib/derived_particle_variables/derived_particle_variable.py
index f400ef74..060b4550 100644
--- a/src/lib/derived_particle_variables/derived_particle_variable.py
+++ b/src/lib/derived_particle_variables/derived_particle_variable.py
@@ -3,6 +3,7 @@
import pandas as pd
+from lib import var_info_registry
from lib.data.data_with_attrs import List
__all__ = ["derived_particle_variable", "derive_particle_variable", "DERIVED_PARTICLE_VARIABLES"]
@@ -25,7 +26,10 @@ def __init__(
def assign_to(self, data: List) -> List:
df = data.data
- return data.assign_data(df.assign(**{self.name: self.derive(*(df[base_var_name] for base_var_name in self.base_var_names))}))
+
+ info = var_info_registry.lookup("prt", self.name)
+ new_var_infos = {**data.metadata.var_infos, self.name: info}
+ return data.assign_data(df.assign(**{self.name: self.derive(*(df[base_var_name] for base_var_name in self.base_var_names))})).assign_metadata(var_infos=new_var_infos)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(({', '.join(self.base_var_names)}) -> {self.name}: {self.derive!r})"
diff --git a/tests/baseline/test_hamscan.png b/tests/baseline/test_hamscan.png
new file mode 100644
index 00000000..4a2607ee
Binary files /dev/null and b/tests/baseline/test_hamscan.png differ
diff --git a/tests/test_plots.py b/tests/test_plots.py
index 3881d74c..bd8e4782 100644
--- a/tests/test_plots.py
+++ b/tests/test_plots.py
@@ -151,6 +151,12 @@ def test_static_scatter_bp():
return make_plot("prt.e -i t=-1 -v y z time= --grid y=0.0625 z=0.0625".split(), data_dir="test-3d")
+@pytest.mark.mpl_image_compare(**MPL_KWARGS)
+def test_hamscan():
+ """Archetypal scan for hammerhead distributions ("hams")."""
+ return make_plot("prt.e -i t=-1 --derive pzx --bin y py=20 pzx=20 t= -v py pzx time=y --compute".split(), data_dir="test-2d")
+
+
# --- Particle moments ---