Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
[project.optional-dependencies]
test = ["pytest>=8.0", "pytest-mpl>=0.17"]
dev = ["ruff>=0.6"]
dask-graph = ["graphviz>=0.20"]

[project.scripts]
psc-plot = "lib.cli:main"
Expand Down
44 changes: 44 additions & 0 deletions src/lib/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import sys
import warnings
import webbrowser
from pathlib import Path

import dask
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -35,6 +37,41 @@ def _resolve_save_format(args: Args) -> SaveFormat | None:
return "gif"


def _run_dask_graph(args: Args) -> None:
    """Write the pipeline's dask task graph to an SVG file instead of plotting.

    The graph is saved as ``<save-file-stem>.daskgraph.svg`` inside
    ``args.save`` (or the current working directory when that is unset). When
    ``args.show`` is set, the written file is also opened via ``webbrowser``.

    Exits the process with status 1 if the pipeline produced no dask
    collections or if the ``graphviz`` package cannot be imported.
    """
    data = args.get_data()

    collections = data.dask_collections()
    if not collections:
        print(
            f"error: --dask-graph requires dask-backed data; pipeline produced eager {type(data).__name__}",
            file=sys.stderr,
        )
        sys.exit(1)

    # Probe for the optional dependency up front so the user gets an
    # actionable message rather than a failure from inside dask.visualize.
    try:
        import graphviz  # noqa: F401
    except ImportError:
        print(
            "error: --dask-graph requires the 'graphviz' package; install with `pip install -e \".[dask-graph]\"`",
            file=sys.stderr,
        )
        sys.exit(1)

    save_dir = args.save or Path.cwd()
    # parents=True: a nested --save target whose ancestors do not exist yet
    # would otherwise raise FileNotFoundError here.
    save_dir.mkdir(parents=True, exist_ok=True)
    path = save_dir / f"{args.get_save_file_stem()}.daskgraph.svg"
    # dask.visualize's optimize_graph flag only lowers legacy HLG collections
    # (e.g. dask Arrays), not new-style Expr ones (dask DataFrames) — without
    # pre-optimizing the latter, un-lowered nodes (e.g. Concat from dd.concat)
    # fail with NotImplementedError in _layer.
    collections = [c.optimize() if hasattr(c, "optimize") else c for c in collections]
    dask.visualize(*collections, filename=str(path), optimize_graph=True)
    print(f"wrote to {path}")

    if args.show:
        webbrowser.open(path.absolute().as_uri())


def main():
dask.config.set(num_workers=CONFIG.dask_num_workers)
if CONFIG.dask_scheduler == "distributed":
Expand All @@ -46,6 +83,13 @@ def main():
dask.config.set(scheduler=CONFIG.dask_scheduler)

args = parsing.get_parsed_args()

if args.dask_graph:
if args.save_format is not None:
warnings.warn("--save-format is ignored with --dask-graph")
_run_dask_graph(args)
return

# resolve format BEFORE applying pipeline in order to fail early
format = _resolve_save_format(args)

Expand Down
12 changes: 12 additions & 0 deletions src/lib/data/data_with_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def lower_bound(self, dim_name: str) -> float: ...
@abstractmethod
def upper_bound(self, dim_name: str) -> float: ...

@abstractmethod
def dask_collections(self) -> list:
    """Return the dask collections backing this data as a list.

    An empty list means the data is not dask-backed.
    """
    ...


@dataclass(kw_only=True, frozen=True)
class FieldMetadata(Metadata):
Expand Down Expand Up @@ -166,6 +169,9 @@ def var_bounds(self) -> tuple[float, float]:
active = self.active_data
return dask.compute(np.min(active), np.max(active))

def dask_collections(self) -> list:
    """Return the backing arrays of the data variables that are dask collections."""
    backing_arrays = (variable.data for variable in self.data.data_vars.values())
    return [array for array in backing_arrays if dask.is_dask_collection(array)]


@dataclass(kw_only=True, frozen=True)
class ListMetadata(Metadata):
Expand Down Expand Up @@ -240,6 +246,9 @@ def upper_bound(self, dim_name) -> float:
cache[dim_name] = self.data[dim_name].max(skipna=True)
return cache[dim_name]

def dask_collections(self) -> list:
    """Return an empty list: this implementation exposes no dask collections."""
    return []


class LazyList(List[dd.DataFrame]):
data: dd.DataFrame
Expand All @@ -266,3 +275,6 @@ def lower_bound(self, dim_name) -> float:

def upper_bound(self, dim_name) -> float:
return self.bounds(dim_name)[1]

def dask_collections(self) -> list:
    """Return `self.data` as the single entry of a list."""
    return [self.data]
84 changes: 63 additions & 21 deletions src/lib/data/loaders/particle_bp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import dask.dataframe as dd
import numpy as np
import pandas as pd
import xarray as xr

from lib.config import CONFIG
Expand All @@ -24,22 +25,46 @@ def _read_attrs(path: pathlib.Path) -> dict:
return {k: ds.attrs[k] for k in ds.attrs}


def _load_step_df(path: pathlib.Path, time: float) -> dd.DataFrame:
"""Open one BP step lazily and return a per-step dask DataFrame with a
constant `t` column. Drops the BP-assigned particle-dim index column.
def _peek_size(path: pathlib.Path) -> tuple[str, int]:
    """Return the file's particle-dim name and length without reading data.

    The particle dim is taken to be the first dimension with more than one
    entry.

    Raises:
        ValueError: if no dimension in the file has more than one entry.
    """
    with xr.open_dataset(path) as ds:
        # Use a default so a file without such a dimension surfaces as a
        # descriptive error instead of a bare StopIteration from next().
        found = next(((d, n) for d, n in ds.sizes.items() if n > 1), None)
    if found is None:
        raise ValueError(f"{path}: no dimension with more than one entry; cannot identify the particle dimension")
    return found


Note: the `t` column is added via map_partitions rather than
dd.DataFrame.assign — the latter creates a broadcast-scalar column whose
`to_dask_array()` trips an IndexError in dask-expr's optimizer when the
dataframe came from xarray's to_dask_dataframe. map_partitions produces a
proper per-row column that survives the optimizer.
"""
with xr.open_dataset(path) as raw:
particle_dim = next(d for d, n in raw.sizes.items() if n > 1)
ds = xr.open_dataset(path, chunks={particle_dim: CONFIG.dask_chunk_size}).squeeze(drop=True)
df = ds.to_dask_dataframe().drop(columns=[particle_dim])
df = df.map_partitions(lambda p, t: p.assign(t=t), np.float64(time))
return df
def _build_meta(path: pathlib.Path) -> pd.DataFrame:
    """Return a zero-row pandas DataFrame describing a per-partition read.

    It has one column per data variable of the file (excluding the particle
    dim), each with that variable's dtype, followed by a float64 `t` column.
    """
    with xr.open_dataset(path) as ds:
        particle_dim = next(name for name, length in ds.sizes.items() if length > 1)
        column_dtypes = {name: ds[name].dtype for name in ds.data_vars if name != particle_dim}
    empty_columns = {name: pd.Series(dtype=dtype) for name, dtype in column_dtypes.items()}
    meta = pd.DataFrame(empty_columns)
    meta["t"] = pd.Series(dtype=np.float64)
    return meta


def _read_chunk(
    path: pathlib.Path,
    time: float,
    particle_dim: str,
    slice_obj: slice,
    columns: list[str] | None = None,
) -> pd.DataFrame:
    """Read one chunk-slice of one BP file as a pandas DataFrame with a `t`
    column.

    The `columns` keyword is populated by dask-expr's column-projection
    optimizer; when supplied, only those variables are read from disk and the
    result holds the requested columns in the requested order. With
    `columns=None`, every data variable except `particle_dim` is read and `t`
    is appended last.
    """
    # Open via a context manager so the file handle is released after every
    # partition read. All values are materialised into NumPy/pandas objects
    # inside the block, so nothing lazy outlives the close.
    with xr.open_dataset(path) as ds:
        if columns is None:
            wanted_vars = [v for v in ds.data_vars if v != particle_dim]
        else:
            wanted_vars = [c for c in columns if c != "t" and c in ds.data_vars]

        if wanted_vars:
            sliced = ds[wanted_vars].isel({particle_dim: slice_obj}).squeeze(drop=True)
            # atleast_1d: a one-row slice loses its particle dim in squeeze(),
            # leaving 0-d values that pandas rejects as all-scalar input.
            pdf = pd.DataFrame({var: np.atleast_1d(sliced[var].values) for var in sliced.data_vars})
        else:
            # Only `t` (or nothing) was requested: build a frame with the
            # right number of rows without touching any variable.
            n_rows = len(range(*slice_obj.indices(ds.sizes[particle_dim])))
            pdf = pd.DataFrame(index=pd.RangeIndex(n_rows))

    if columns is None or "t" in columns:
        pdf["t"] = np.float64(time)
    if columns is not None:
        pdf = pdf[[c for c in columns if c in pdf.columns]]
    return pdf


_SPECIES_KEY_RE = re.compile(r"^([a-zA-Z]+)([+-]*)(\d*)$")
Expand Down Expand Up @@ -87,14 +112,31 @@ def get_data(self) -> LazyList:
info = SpeciesInfo(self.species_key, display, q, m)
species_dict = {self.species_key: info}

dfs = [_load_step_df(_get_path(self.prefix, step), time) for step, time in zip(self.steps, times)]
df = dd.concat(dfs)

# Build per-partition iterables for dd.from_map, chunking each file
# along its particle dim. dd.from_map propagates downstream column
# projection into `_read_chunk` via its `columns` kwarg, so unused
# variables are never read from disk.
chunk_size = CONFIG.dask_chunk_size
paths: list[pathlib.Path] = []
step_times: list[float] = []
particle_dims: list[str] = []
slices: list[slice] = []
partition_ranges = []
offset = 0
for d in dfs:
partition_ranges.append((offset, offset + d.npartitions))
offset += d.npartitions
for step, time in zip(self.steps, times):
path = _get_path(self.prefix, step)
particle_dim, n = _peek_size(path)
n_chunks = max(1, (n + chunk_size - 1) // chunk_size)
partition_ranges.append((offset, offset + n_chunks))
offset += n_chunks
for i in range(n_chunks):
paths.append(path)
step_times.append(float(time))
particle_dims.append(particle_dim)
slices.append(slice(i * chunk_size, (i + 1) * chunk_size))

meta = _build_meta(paths[0])
df = dd.from_map(_read_chunk, paths, step_times, particle_dims, slices, meta=meta)

corners = np.asarray(head["corner"])
lengths = np.asarray(head["length"])
Expand Down
9 changes: 7 additions & 2 deletions src/lib/parsing/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from lib.data.adaptor import Adaptor
from lib.data.compile import compile_source
from lib.data.data_source import DataSource
from lib.data.data_with_attrs import DataWithAttrs
from lib.plotting.get_plot import get_plot
from lib.plotting.hook import Hook
from lib.plotting.plot import Plot
Expand All @@ -18,10 +19,14 @@ class Args(argparse.Namespace):
show: bool
save: Path | None
save_format: str | None
dask_graph: bool

def get_animation(self) -> Plot:
def get_data(self) -> DataWithAttrs:
source = compile_source(self.loader, self.adaptors)
data = source.get_data()
return source.get_data()

def get_animation(self) -> Plot:
data = self.get_data()

plot = get_plot(data)

Expand Down
5 changes: 5 additions & 0 deletions src/lib/parsing/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def _get_parser(prefixes: Iterable[str]) -> argparse.ArgumentParser:
default=None,
help="format for saved animations (default: mp4, falls back to gif if ffmpeg unavailable)",
)
parser.add_argument(
"--dask-graph",
action="store_true",
help="visualize the pipeline's dask graph as SVG instead of rendering a plot",
)

for custom_arg in CUSTOM_ARGS:
custom_arg.add_to(parser)
Expand Down
55 changes: 55 additions & 0 deletions tests/test_dask_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Regression tests for the structure of the dask graph produced by the particle
pipeline. The pipeline's downstream selects a small subset of columns; the
column-projection optimization in dask-expr should propagate that selection all
the way down to the per-file reads, so unused columns are never read from disk.

These tests inspect the *optimized* expression tree and assert that read tasks
for unwanted columns are absent. They guard against regressions where an opaque
operation (e.g. an inline `map_partitions(lambda)`) blocks projection pushdown,
silently causing dead loads of every column in every file.
"""

from conftest import _DATA_DIR

from lib.config import CONFIG
from lib.parsing.parse import get_parsed_args


def _read_keys_for_columns(args_list: list[str], data_dir: str = "test-2d") -> list[str]:
    """Optimize each dask collection produced by `args_list` and return the
    list of file-read task key strings found in the optimized graphs.

    `CONFIG.data_dir` is pointed at `_DATA_DIR / data_dir` for the duration of
    the call and restored afterwards, even on failure.
    """
    saved_data_dir = CONFIG.data_dir
    CONFIG.data_dir = _DATA_DIR / data_dir
    try:
        collections = get_parsed_args(args_list).get_data().dask_collections()
        assert collections, "expected particle pipeline to be dask-backed"
        read_keys: list[str] = []
        for collection in collections:
            optimized = collection.optimize() if hasattr(collection, "optimize") else collection
            # Graph keys are either plain names or (name, index, ...) tuples.
            names = (k[0] if isinstance(k, tuple) else k for k in optimized.__dask_graph__())
            # NOTE(review): a read task is recognised by "open_dataset" in its
            # key. Confirm the loader's read tasks still carry that substring;
            # if not, this list is always empty and callers pass vacuously.
            read_keys.extend(name for name in names if isinstance(name, str) and "open_dataset" in name)
        return read_keys
    finally:
        CONFIG.data_dir = saved_data_dir


def test_particle_load_projects_columns_to_reads():
    """`prt.i -v y py` needs only y and py, so no read-task key in the
    optimized graph may name px, pz, w, x, or z."""
    read_keys = _read_keys_for_columns(["prt.i", "-v", "y", "py", "-q"])
    leaked = sorted(
        column
        for column in ("px", "pz", "w", "x", "z")
        if any(f"-{column}-" in key for key in read_keys)
    )
    assert not leaked, f"unprojected columns still being read from disk: {leaked}"


def test_particle_load_projects_columns_in_binned_pipeline():
    """`--bin y py=16 -v y py` needs only y, py, and the weight column, so no
    read-task key in the optimized graph may name px, pz, x, or z."""
    read_keys = _read_keys_for_columns(["prt.i", "--bin", "y", "py=16", "-v", "y", "py", "-q"])
    leaked = sorted(
        column
        for column in ("px", "pz", "x", "z")
        if any(f"-{column}-" in key for key in read_keys)
    )
    assert not leaked, f"unprojected columns still being read from disk: {leaked}"
Loading