Skip to content

Commit d58a485

Browse files
committed
Make rapidsmpf py_executor configurable
This adds a new configuration option to control the number of threads passed to the ThreadPoolExecutor used as rapidsmpf's `py_executor`. This pairs with rapidsai/rapidsmpf#858, which starts using the user-provided `py_executor` for `asyncio.to_thread` calls.
1 parent e323fd8 commit d58a485

3 files changed

Lines changed: 48 additions & 1 deletion

File tree

python/cudf_polars/cudf_polars/experimental/rapidsmpf/core.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,10 @@ def evaluate_pipeline(
     )

     # Run the network
-    executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="cpse")
+    executor = ThreadPoolExecutor(
+        max_workers=config_options.executor.rapidsmpf_py_executor_max_workers,
+        thread_name_prefix="cpse",
+    )
     run_streaming_pipeline(nodes=nodes, py_executor=executor)

     # Extract/return the concatenated result.

python/cudf_polars/cudf_polars/utils/config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,10 @@ class StreamingExecutor:
         or use regular pageable host memory. Pinned host memory offers higher
         bandwidth and lower latency for device to host transfers compared to
         regular pageable host memory.
+    rapidsmpf_py_executor_max_workers
+        Maximum number of workers for the Python ThreadPoolExecutor used by
+        the rapidsmpf runtime. Default is None, which uses ThreadPoolExecutor's
+        default behavior. This option is only used by the "rapidsmpf" runtime.

     Notes
     -----
@@ -812,6 +816,11 @@ class StreamingExecutor:
             f"{_env_prefix}__SPILL_TO_PINNED_MEMORY", bool, default=False
         )
     )
+    rapidsmpf_py_executor_max_workers: int | None = dataclasses.field(
+        default_factory=_make_default_factory(
+            f"{_env_prefix}__RAPIDSMPF_PY_EXECUTOR_MAX_WORKERS", int, default=None
+        )
+    )

     def __post_init__(self) -> None:  # noqa: D105
         # Check for rapidsmpf runtime
@@ -959,6 +968,8 @@ def __post_init__(self) -> None:  # noqa: D105
             raise TypeError("max_io_threads must be an int")
         if not isinstance(self.spill_to_pinned_memory, bool):
             raise TypeError("spill_to_pinned_memory must be bool")
+        if not isinstance(self.rapidsmpf_py_executor_max_workers, (int, type(None))):
+            raise TypeError("rapidsmpf_py_executor_max_workers must be int or None")

         # RapidsMPF spill is only supported for distributed clusters for now.
         # This is because the spilling API is still within the RMPF-Dask integration.

python/cudf_polars/tests/test_config.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ def test_validate_shuffle_insertion_method() -> None:
         "client_device_threshold",
         "max_io_threads",
         "spill_to_pinned_memory",
+        "rapidsmpf_py_executor_max_workers",
     ],
 )
 def test_validate_streaming_executor_options(option: str) -> None:
@@ -1010,3 +1011,35 @@ def test_rapidsmpf_distributed_warns(monkeypatch: pytest.MonkeyPatch) -> None:
             },
         )
     )
+
+
+def test_rapidsmpf_py_executor_max_workers_default() -> None:
+    config = ConfigOptions.from_polars_engine(
+        pl.GPUEngine(
+            executor="streaming",
+        )
+    )
+    assert config.executor.name == "streaming"
+    assert config.executor.rapidsmpf_py_executor_max_workers is None
+
+
+def test_rapidsmpf_py_executor_max_workers_from_executor_options() -> None:
+    config = ConfigOptions.from_polars_engine(
+        pl.GPUEngine(
+            executor="streaming",
+            executor_options={"rapidsmpf_py_executor_max_workers": 4},
+        )
+    )
+    assert config.executor.name == "streaming"
+    assert config.executor.rapidsmpf_py_executor_max_workers == 4
+
+
+def test_rapidsmpf_py_executor_max_workers_from_env(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    with monkeypatch.context() as m:
+        m.setenv("CUDF_POLARS__EXECUTOR__RAPIDSMPF_PY_EXECUTOR_MAX_WORKERS", "8")
+        engine = pl.GPUEngine(executor="streaming")
+        config = ConfigOptions.from_polars_engine(engine)
+        assert config.executor.name == "streaming"
+        assert config.executor.rapidsmpf_py_executor_max_workers == 8

0 commit comments

Comments
 (0)