diff --git a/CLAUDE.md b/CLAUDE.md index 667be1b..c3b3e88 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -33,7 +33,7 @@ Required environment (see `src/lib/config.py`): - `PSC_PLOT_FFMPEG_BIN` — optional, falls back to `which ffmpeg`; needed for saving animations - `PSC_PLOT_DASK_NUM_WORKERS` — optional, defaults to `os.cpu_count()` - `PSC_PLOT_DASK_CHUNK_SIZE` — optional, rows per dask partition for particle loads (default 1_000_000); reduce to bound peak memory on large files -- `PSC_PLOT_DASK_SCHEDULER` — optional; if set to `"processes"`, uses dask's processes scheduler; if `"distributed"`, spins up a `dask.distributed.LocalCluster` with `n_workers=dask_num_workers, threads_per_worker=1, processes=True`. Unset = dask default (threads). +- `PSC_PLOT_DASK_SCHEDULER` — optional; if set to `"processes"`, uses dask's processes scheduler; if `"distributed"`, spins up a `dask.distributed.LocalCluster` with `n_workers=dask_num_workers, threads_per_worker=1, processes=True`; if `"mpi"`, calls `dask_mpi.initialize(nthreads=1)` — this requires the `[mpi]` extra (`pip install -e ".[mpi]"`) and the program must be launched under `mpirun`/`srun` with at least 3 ranks, i.e. `-np >= 3` (rank 0 = scheduler, rank 1 = client, rest = workers; only rank 1 reaches the rest of `main()`). Unset = dask default (threads). `PSC_PLOT_DATA_DIR` is read at module-import time (`CONFIG = PscPlotConfig._load()` in `src/lib/config.py`), so it must be set in the environment before any `lib.*` import. In tests, `tests/conftest.py` sets it before importing `lib`. 
diff --git a/pyproject.toml b/pyproject.toml index 00f4ecb..afb1b84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ test = ["pytest>=8.0", "pytest-mpl>=0.17"] dev = ["ruff>=0.6"] dask-graph = ["graphviz>=0.20"] +mpi = ["dask-mpi>=2022.4", "mpi4py>=3.1"] [project.scripts] psc-plot = "lib.cli:main" diff --git a/src/lib/cli.py b/src/lib/cli.py index 52ece0b..2258e6f 100644 --- a/src/lib/cli.py +++ b/src/lib/cli.py @@ -58,7 +58,7 @@ def _run_dask_graph(args: Args) -> None: sys.exit(1) save_dir = args.save or Path.cwd() - save_dir.mkdir(exist_ok=True) + save_dir.mkdir(exist_ok=True, parents=True) path = save_dir / f"{args.get_save_file_stem()}.daskgraph.svg" # dask.visualize's optimize_graph flag only lowers legacy HLG collections # (e.g. dask Arrays), not new-style Expr ones (dask DataFrames) — without @@ -72,9 +72,26 @@ def _run_dask_graph(args: Args) -> None: webbrowser.open(path.absolute().as_uri()) +def _init_mpi_scheduler() -> None: + try: + from dask_mpi import initialize + except ImportError: + print( + "error: PSC_PLOT_DASK_SCHEDULER=mpi requires dask-mpi; install with 'pip install -e \".[mpi]\"'", + file=sys.stderr, + ) + sys.exit(1) + from dask.distributed import Client + + initialize(nthreads=1) + Client() + + def main(): dask.config.set(num_workers=CONFIG.dask_num_workers) - if CONFIG.dask_scheduler == "distributed": + if CONFIG.dask_scheduler == "mpi": + _init_mpi_scheduler() + elif CONFIG.dask_scheduler == "distributed": from dask.distributed import Client, LocalCluster cluster = LocalCluster(n_workers=CONFIG.dask_num_workers, threads_per_worker=1, processes=True) @@ -101,7 +118,7 @@ def main(): if args.show: plot.show() if args.save is not None: - args.save.mkdir(exist_ok=True) + args.save.mkdir(exist_ok=True, parents=True) if format not in plot.allowed_save_formats(): if format == args.save_format: # user actually specified this format