diff --git a/legate/core/context.py b/legate/core/context.py index 63c4beb1dc..069fca9b69 100644 --- a/legate/core/context.py +++ b/legate/core/context.py @@ -370,7 +370,13 @@ def issue_execution_fence(self, block: bool = False) -> None: """ self._runtime.issue_execution_fence(block=block) - def tree_reduce(self, task_id: int, store: Store, radix: int = 4) -> Store: + def tree_reduce( + self, + task_id: int, + store: Store, + radix: int = 4, + scalar_args: list[tuple[Any, Dtype]] = [], + ) -> Store: """ Performs a user-defined reduction by building a tree of reduction tasks. At each step, the reducer task gets up to ``radix`` input stores @@ -399,4 +405,6 @@ def tree_reduce(self, task_id: int, store: Store, radix: int = 4) -> Store: Store Store that contains reduction results """ - return self._runtime.tree_reduce(self, task_id, store, radix) + return self._runtime.tree_reduce( + self, task_id, store, radix, scalar_args + ) diff --git a/legate/core/operation.py b/legate/core/operation.py index 2483cadfb1..ea493fe523 100644 --- a/legate/core/operation.py +++ b/legate/core/operation.py @@ -1494,6 +1494,7 @@ def __init__( ) self._radix = radix self._task_id = task_id + self._scalar_args: list[tuple[Any, ty.Dtype]] = [] def add_input(self, store: Store) -> None: self._check_store(store) @@ -1508,6 +1509,9 @@ def add_output(self, store: Store) -> None: self._outputs.append(store) self._output_parts.append(partition) + def add_scalar_arg(self, value: Any, dtype: ty.Dtype) -> None: + self._scalar_args.append((value, dtype)) + def launch(self, strategy: Strategy) -> None: assert len(self._inputs) == 1 and len(self._outputs) == 1 @@ -1537,6 +1541,9 @@ def launch(self, strategy: Strategy) -> None: provenance=self.provenance, ) + for scalar, dtype in self._scalar_args: + launcher.add_scalar_arg(scalar, dtype) + if num_tasks > 1: for proj_fn in proj_fns: launcher.add_input( diff --git a/legate/core/runtime.py b/legate/core/runtime.py index 0aca17e5a6..96ca2f0712 100644 --- a/legate/core/runtime.py +++ b/legate/core/runtime.py @@ -1731,7 +1731,12 @@ def issue_fill( fill.execute() def tree_reduce( - self, context: Context, task_id: int, store: Store, radix: int = 4 + self, + context: Context, + task_id: int, + store: Store, + radix: int = 4, + scalar_args: list[tuple[Any, ty.Dtype]] = [], ) -> Store: """ Performs a user-defined reduction by building a tree of reduction @@ -1780,6 +1785,8 @@ def tree_reduce( ) task.add_input(store) task.add_output(result) + for scalar_arg in scalar_args: + task.add_scalar_arg(scalar_arg[0], scalar_arg[1]) task.execute() return result