diff --git a/pyproject.toml b/pyproject.toml index c1f2cfd..00f4ecb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ [project.optional-dependencies] test = ["pytest>=8.0", "pytest-mpl>=0.17"] dev = ["ruff>=0.6"] +dask-graph = ["graphviz>=0.20"] [project.scripts] psc-plot = "lib.cli:main" diff --git a/src/lib/cli.py b/src/lib/cli.py index 799b608..52ece0b 100644 --- a/src/lib/cli.py +++ b/src/lib/cli.py @@ -1,5 +1,7 @@ import sys import warnings +import webbrowser +from pathlib import Path import dask import matplotlib.pyplot as plt @@ -35,6 +37,41 @@ def _resolve_save_format(args: Args) -> SaveFormat | None: return "gif" +def _run_dask_graph(args: Args) -> None: + data = args.get_data() + + collections = data.dask_collections() + if not collections: + print( + f"error: --dask-graph requires dask-backed data; pipeline produced eager {type(data).__name__}", + file=sys.stderr, + ) + sys.exit(1) + + try: + import graphviz # noqa: F401 + except ImportError: + print( + "error: --dask-graph requires the 'graphviz' package; install with `pip install -e \".[dask-graph]\"`", + file=sys.stderr, + ) + sys.exit(1) + + save_dir = args.save or Path.cwd() + save_dir.mkdir(exist_ok=True) + path = save_dir / f"{args.get_save_file_stem()}.daskgraph.svg" + # dask.visualize's optimize_graph flag only lowers legacy HLG collections + # (e.g. dask Arrays), not new-style Expr ones (dask DataFrames) — without + # pre-optimizing the latter, un-lowered nodes (e.g. Concat from dd.concat) + # fail with NotImplementedError in _layer. + collections = [c.optimize() if hasattr(c, "optimize") else c for c in collections] + dask.visualize(*collections, filename=str(path), optimize_graph=True) + print(f"wrote to {path}") + + if args.show: + webbrowser.open(path.absolute().as_uri()) + + def main(): dask.config.set(num_workers=CONFIG.dask_num_workers) if CONFIG.dask_scheduler == "distributed": @@ -46,6 +83,13 @@ def main(): dask.config.set(scheduler=CONFIG.dask_scheduler) args = parsing.get_parsed_args() + + if args.dask_graph: + if args.save_format is not None: + warnings.warn("--save-format is ignored with --dask-graph") + _run_dask_graph(args) + return + # resolve format BEFORE applying pipeline in order to fail early format = _resolve_save_format(args) diff --git a/src/lib/data/data_with_attrs.py b/src/lib/data/data_with_attrs.py index 03d4597..6af3e15 100644 --- a/src/lib/data/data_with_attrs.py +++ b/src/lib/data/data_with_attrs.py @@ -110,6 +110,9 @@ def lower_bound(self, dim_name: str) -> float: ... @abstractmethod def upper_bound(self, dim_name: str) -> float: ... + @abstractmethod + def dask_collections(self) -> list: ... + @dataclass(kw_only=True, frozen=True) class FieldMetadata(Metadata): @@ -166,6 +169,9 @@ def var_bounds(self) -> tuple[float, float]: active = self.active_data return dask.compute(np.min(active), np.max(active)) + def dask_collections(self) -> list: + return [da.data for da in self.data.data_vars.values() if dask.is_dask_collection(da.data)] + @dataclass(kw_only=True, frozen=True) class ListMetadata(Metadata): @@ -240,6 +246,9 @@ def upper_bound(self, dim_name) -> float: cache[dim_name] = self.data[dim_name].max(skipna=True) return cache[dim_name] + def dask_collections(self) -> list: + return [] + class LazyList(List[dd.DataFrame]): data: dd.DataFrame @@ -266,3 +275,6 @@ def lower_bound(self, dim_name) -> float: def upper_bound(self, dim_name) -> float: return self.bounds(dim_name)[1] + + def dask_collections(self) -> list: + return [self.data] diff --git a/src/lib/data/loaders/particle_bp.py b/src/lib/data/loaders/particle_bp.py index 0ff72af..85645ae 100644 --- a/src/lib/data/loaders/particle_bp.py +++ b/src/lib/data/loaders/particle_bp.py @@ -3,6 +3,7 @@ import dask.dataframe as dd import numpy as np +import pandas as pd import xarray as xr from lib.config import CONFIG @@ -24,22 +25,46 @@ def _read_attrs(path: pathlib.Path) -> dict: return {k: ds.attrs[k] for k in ds.attrs} -def _load_step_df(path: pathlib.Path, time: float) -> dd.DataFrame: - """Open one BP step lazily and return a per-step dask DataFrame with a - constant `t` column. Drops the BP-assigned particle-dim index column. +def _peek_size(path: pathlib.Path) -> tuple[str, int]: + """Return the file's particle-dim name and length without reading data.""" + with xr.open_dataset(path) as ds: + return next((d, n) for d, n in ds.sizes.items() if n > 1) + - Note: the `t` column is added via map_partitions rather than - dd.DataFrame.assign — the latter creates a broadcast-scalar column whose - `to_dask_array()` trips an IndexError in dask-expr's optimizer when the - dataframe came from xarray's to_dask_dataframe. map_partitions produces a - proper per-row column that survives the optimizer. - """ - with xr.open_dataset(path) as raw: - particle_dim = next(d for d, n in raw.sizes.items() if n > 1) - ds = xr.open_dataset(path, chunks={particle_dim: CONFIG.dask_chunk_size}).squeeze(drop=True) - df = ds.to_dask_dataframe().drop(columns=[particle_dim]) - df = df.map_partitions(lambda p, t: p.assign(t=t), np.float64(time)) - return df +def _build_meta(path: pathlib.Path) -> pd.DataFrame: + """Build an empty pandas DataFrame whose columns/dtypes match a per-partition read.""" + with xr.open_dataset(path) as ds: + particle_dim = next(d for d, n in ds.sizes.items() if n > 1) + dtypes = {var: ds[var].dtype for var in ds.data_vars if var != particle_dim} + meta = pd.DataFrame({var: pd.Series(dtype=dt) for var, dt in dtypes.items()}) + meta["t"] = pd.Series(dtype=np.float64) + return meta + + +def _read_chunk( + path: pathlib.Path, + time: float, + particle_dim: str, + slice_obj: slice, + columns: list[str] | None = None, +) -> pd.DataFrame: + """Read one chunk-slice of one BP file as a pandas DataFrame with a `t` + column. The `columns` keyword is populated by dask-expr's column-projection + optimizer; when supplied, only those variables are read from disk.""" + ds = xr.open_dataset(path) + wanted_vars = [c for c in columns if c != "t" and c in ds.data_vars] if columns is not None else [v for v in ds.data_vars if v != particle_dim] + if wanted_vars: + sliced = ds[wanted_vars].isel({particle_dim: slice_obj}).squeeze(drop=True) + pdf = pd.DataFrame({var: np.asarray(sliced[var].values) for var in sliced.data_vars}) + else: + start, stop, step = slice_obj.indices(ds.sizes[particle_dim]) + n_rows = max(0, (stop - start + step - 1) // step) + pdf = pd.DataFrame(index=pd.RangeIndex(n_rows)) + if columns is None or "t" in columns: + pdf["t"] = np.float64(time) + if columns is not None: + pdf = pdf[[c for c in columns if c in pdf.columns]] + return pdf _SPECIES_KEY_RE = re.compile(r"^([a-zA-Z]+)([+-]*)(\d*)$") @@ -87,14 +112,31 @@ def get_data(self) -> LazyList: info = SpeciesInfo(self.species_key, display, q, m) species_dict = {self.species_key: info} - dfs = [_load_step_df(_get_path(self.prefix, step), time) for step, time in zip(self.steps, times)] - df = dd.concat(dfs) - + # Build per-partition iterables for dd.from_map, chunking each file + # along its particle dim. dd.from_map propagates downstream column + # projection into `_read_chunk` via its `columns` kwarg, so unused + # variables are never read from disk. + chunk_size = CONFIG.dask_chunk_size + paths: list[pathlib.Path] = [] + step_times: list[float] = [] + particle_dims: list[str] = [] + slices: list[slice] = [] partition_ranges = [] offset = 0 - for d in dfs: - partition_ranges.append((offset, offset + d.npartitions)) - offset += d.npartitions + for step, time in zip(self.steps, times): + path = _get_path(self.prefix, step) + particle_dim, n = _peek_size(path) + n_chunks = max(1, (n + chunk_size - 1) // chunk_size) + partition_ranges.append((offset, offset + n_chunks)) + offset += n_chunks + for i in range(n_chunks): + paths.append(path) + step_times.append(float(time)) + particle_dims.append(particle_dim) + slices.append(slice(i * chunk_size, (i + 1) * chunk_size)) + + meta = _build_meta(paths[0]) + df = dd.from_map(_read_chunk, paths, step_times, particle_dims, slices, meta=meta) corners = np.asarray(head["corner"]) lengths = np.asarray(head["length"]) diff --git a/src/lib/parsing/args.py b/src/lib/parsing/args.py index 7d4dd3b..9934a26 100644 --- a/src/lib/parsing/args.py +++ b/src/lib/parsing/args.py @@ -4,6 +4,7 @@ from lib.data.adaptor import Adaptor from lib.data.compile import compile_source from lib.data.data_source import DataSource +from lib.data.data_with_attrs import DataWithAttrs from lib.plotting.get_plot import get_plot from lib.plotting.hook import Hook from lib.plotting.plot import Plot @@ -18,10 +19,14 @@ class Args(argparse.Namespace): show: bool save: Path | None save_format: str | None + dask_graph: bool - def get_animation(self) -> Plot: + def get_data(self) -> DataWithAttrs: source = compile_source(self.loader, self.adaptors) - data = source.get_data() + return source.get_data() + + def get_animation(self) -> Plot: + data = self.get_data() plot = get_plot(data) diff --git a/src/lib/parsing/parse.py b/src/lib/parsing/parse.py index daa14ca..a9cd1d9 100644 --- a/src/lib/parsing/parse.py +++ b/src/lib/parsing/parse.py @@ -31,6 +31,11 @@ def _get_parser(prefixes: Iterable[str]) -> argparse.ArgumentParser: default=None, help="format for saved animations (default: mp4, falls back to gif if ffmpeg unavailable)", ) + parser.add_argument( + "--dask-graph", + action="store_true", + help="visualize the pipeline's dask graph as SVG instead of rendering a plot", + ) for custom_arg in CUSTOM_ARGS: custom_arg.add_to(parser) diff --git a/tests/test_dask_graph.py b/tests/test_dask_graph.py new file mode 100644 index 0000000..5fe6938 --- /dev/null +++ b/tests/test_dask_graph.py @@ -0,0 +1,55 @@ +"""Regression tests for the structure of the dask graph produced by the particle +pipeline. The pipeline's downstream selects a small subset of columns; the +column-projection optimization in dask-expr should propagate that selection all +the way down to the per-file reads, so unused columns are never read from disk. + +These tests inspect the *optimized* expression tree and assert that read tasks +for unwanted columns are absent. They guard against regressions where an opaque +operation (e.g. an inline `map_partitions(lambda)`) blocks projection pushdown, +silently causing dead loads of every column in every file. +""" + +from conftest import _DATA_DIR + +from lib.config import CONFIG +from lib.parsing.parse import get_parsed_args + + +def _read_keys_for_columns(args_list: list[str], data_dir: str = "test-2d") -> list[str]: + """Optimize each dask collection produced by `args_list` and return + the set of per-column file-read task key strings in the optimized graph.""" + original = CONFIG.data_dir + CONFIG.data_dir = _DATA_DIR / data_dir + try: + args = get_parsed_args(args_list) + data = args.get_data() + collections = data.dask_collections() + assert collections, "expected particle pipeline to be dask-backed" + read_keys: list[str] = [] + for c in collections: + opt = c.optimize() if hasattr(c, "optimize") else c + for k in opt.__dask_graph__(): + key = k[0] if isinstance(k, tuple) else k + if isinstance(key, str) and "open_dataset" in key: + read_keys.append(key) + return read_keys + finally: + CONFIG.data_dir = original + + +def test_particle_load_projects_columns_to_reads(): + """`prt.i -v y py` only needs y and py; the optimized graph must not + contain read tasks for px, pz, w, x, or z.""" + read_keys = _read_keys_for_columns(["prt.i", "-v", "y", "py", "-q"]) + unwanted = ["px", "pz", "w", "x", "z"] + leaked = sorted({c for c in unwanted if any(f"-{c}-" in k for k in read_keys)}) + assert not leaked, f"unprojected columns still being read from disk: {leaked}" + + +def test_particle_load_projects_columns_in_binned_pipeline(): + """`--bin y py=16 -v y py` only needs y, py, and the weight column. The + optimized graph must not contain read tasks for px, pz, x, or z.""" + read_keys = _read_keys_for_columns(["prt.i", "--bin", "y", "py=16", "-v", "y", "py", "-q"]) + unwanted = ["px", "pz", "x", "z"] + leaked = sorted({c for c in unwanted if any(f"-{c}-" in k for k in read_keys)}) + assert not leaked, f"unprojected columns still being read from disk: {leaked}"