Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 35 additions & 13 deletions python/cudf_polars/cudf_polars/experimental/benchmarks/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""Utility functions/classes for running the PDS-H and PDS-DS benchmarks."""
Expand Down Expand Up @@ -521,16 +521,27 @@ def initialize_dask_cluster(run_config: RunConfig, args: argparse.Namespace): #
from dask_cuda import LocalCUDACluster
from distributed import Client

kwargs = {
# Base cluster configuration
kwargs: dict[str, object] = {
"n_workers": run_config.n_workers,
"dashboard_address": ":8585",
"protocol": args.protocol,
"rmm_pool_size": args.rmm_pool_size,
"rmm_async": args.rmm_async,
"rmm_release_threshold": args.rmm_release_threshold,
"threads_per_worker": run_config.threads,
}

# Only let rapidsmpf handle RMM configuration when rapidsmpf must be installed.
use_rapidsmpf_rmm = (
run_config.shuffle == "rapidsmpf" or run_config.runtime == "rapidsmpf"
)
if not use_rapidsmpf_rmm:
kwargs.update(
{
"rmm_pool_size": args.rmm_pool_size,
"rmm_async": args.rmm_async,
"rmm_release_threshold": args.rmm_release_threshold,
}
)

# Avoid UVM in distributed cluster
client = Client(LocalCUDACluster(**kwargs))
client.wait_for_workers(run_config.n_workers)
Expand All @@ -540,16 +551,27 @@ def initialize_dask_cluster(run_config: RunConfig, args: argparse.Namespace): #
from rapidsmpf.config import Options
from rapidsmpf.integrations.dask import bootstrap_dask_cluster

# Build rapidsmpf options including RMM configuration
rapidsmpf_options = {
"dask_spill_device": str(run_config.spill_device),
"dask_statistics": str(args.rapidsmpf_dask_statistics),
"dask_print_statistics": str(args.rapidsmpf_print_statistics),
"oom_protection": str(args.rapidsmpf_oom_protection),
}

# Let rapidsmpf handle RMM configuration (supports fractions like "0.5")
if use_rapidsmpf_rmm:
if args.rmm_pool_size is not None:
rapidsmpf_options["dask_rmm_pool_size"] = str(args.rmm_pool_size)
rapidsmpf_options["dask_rmm_async"] = str(args.rmm_async).lower()
if args.rmm_release_threshold is not None:
rapidsmpf_options["dask_rmm_release_threshold"] = str(
args.rmm_release_threshold
)

bootstrap_dask_cluster(
client,
options=Options(
{
"dask_spill_device": str(run_config.spill_device),
"dask_statistics": str(args.rapidsmpf_dask_statistics),
"dask_print_statistics": str(args.rapidsmpf_print_statistics),
"oom_protection": str(args.rapidsmpf_oom_protection),
}
),
options=Options(rapidsmpf_options),
)
# Setting this globally makes the peak statistics not meaningful
# across queries / iterations. But doing it per query isn't worth
Expand Down