diff --git a/cunumeric/__init__.py b/cunumeric/__init__.py index 13c8504b8b..f447cec850 100644 --- a/cunumeric/__init__.py +++ b/cunumeric/__init__.py @@ -36,6 +36,7 @@ from cunumeric.logic import * from cunumeric.window import bartlett, blackman, hamming, hanning, kaiser from cunumeric.coverage import clone_module +from cunumeric.vectorize import vectorize clone_module(_np, globals(), maybe_convert_to_np_ndarray) diff --git a/cunumeric/config.py b/cunumeric/config.py index 21a7a68e58..abcd2f3eca 100644 --- a/cunumeric/config.py +++ b/cunumeric/config.py @@ -143,9 +143,11 @@ class _CunumericSharedLib: CUNUMERIC_CONVERT_NAN_PROD: int CUNUMERIC_CONVERT_NAN_SUM: int CUNUMERIC_CONVOLVE: int + CUNUMERIC_CREATE_CU_KERNEL: int CUNUMERIC_DIAG: int CUNUMERIC_DOT: int CUNUMERIC_EYE: int + CUNUMERIC_EVAL_UDF: int CUNUMERIC_FFT: int CUNUMERIC_FFT_C2C: int CUNUMERIC_FFT_C2R: int @@ -332,9 +334,11 @@ class CuNumericOpCode(IntEnum): CONTRACT = _cunumeric.CUNUMERIC_CONTRACT CONVERT = _cunumeric.CUNUMERIC_CONVERT CONVOLVE = _cunumeric.CUNUMERIC_CONVOLVE + CREATE_CU_KERNEL = _cunumeric.CUNUMERIC_CREATE_CU_KERNEL DIAG = _cunumeric.CUNUMERIC_DIAG DOT = _cunumeric.CUNUMERIC_DOT EYE = _cunumeric.CUNUMERIC_EYE + EVAL_UDF = _cunumeric.CUNUMERIC_EVAL_UDF FFT = _cunumeric.CUNUMERIC_FFT FILL = _cunumeric.CUNUMERIC_FILL FLIP = _cunumeric.CUNUMERIC_FLIP diff --git a/cunumeric/utils.py b/cunumeric/utils.py index 0586bb8f36..ddd59f50aa 100644 --- a/cunumeric/utils.py +++ b/cunumeric/utils.py @@ -22,6 +22,7 @@ import legate.core.types as ty import numpy as np +import pyarrow as pa from .types import NdShape @@ -42,6 +43,25 @@ np.dtype(np.complex128): ty.complex128, } +CUNUMERIC_TYPE_MAP = { + bool: ty.bool_, + int: ty.int64, + float: ty.float64, + complex: ty.complex128, + pa.bool_: ty.bool_, + pa.int8: ty.int8, + pa.int16: ty.int16, + pa.int32: ty.int32, + pa.int64: ty.int64, # np.int is int + pa.uint8: ty.uint8, + pa.uint16: ty.uint16, + pa.uint32: ty.uint32, + pa.uint64: ty.uint64, # np.uint is np.uint64 + pa.float16: ty.float16, + pa.float32: ty.float32, + pa.float64: ty.float64, +} + def to_core_dtype(dtype: Union[str, np.dtype[Any]]) -> Optional[ty.Dtype]: return SUPPORTED_DTYPES.get(np.dtype(dtype)) @@ -94,6 +114,12 @@ def find_last_user_frames(top_only: bool = True) -> str: return "|".join(get_line_number_from_frame(f) for f in frames) +def convert_to_cunumeric_dtype(dtype: Any) -> Any: + if dtype in CUNUMERIC_TYPE_MAP: + return CUNUMERIC_TYPE_MAP[dtype] + raise TypeError("dtype is not supported") + + def calculate_volume(shape: NdShape) -> int: if len(shape) == 0: return 0 diff --git a/cunumeric/vectorize.py b/cunumeric/vectorize.py new file mode 100644 index 0000000000..c3f691d164 --- /dev/null +++ b/cunumeric/vectorize.py @@ -0,0 +1,640 @@ +# Copyright 2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#

+import inspect
+import re
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import legate.core.types as ty
+import numba
+import numba.core.ccallback
+import numpy as np
+from legate.core import Rect, track_provenance
+
+from cunumeric.runtime import runtime
+
+from .array import convert_to_cunumeric_ndarray
+from .config import CuNumericOpCode
+from .module import full
+from .utils import convert_to_cunumeric_dtype
+
+_EXTERNAL_REFERENCE_PREFIX = "__extern_ref__"
+_MASK_VAR = "__mask__"
+_SIZE_VAR = "__size__"
+_LOOP_VAR = "__i__"
+_ARGS_VAR = "__args__"
+_DIM_VAR = "__dim__"
+_STRIDES_VAR = "__strides__"
+_PITCHES_VAR = "__pitches__"
+
+
+class vectorize:
+    def __init__(
+        self,
+        pyfunc: Callable[[Any], Any],
+        otypes: Optional[Union[str, list[Any]]] = None,
+        doc: Optional[str] = None,
+        excluded: Optional[set[Any]] = None,
+        cache: bool = False,
+        signature: Optional[str] = None,
+    ) -> None:
+        """
+        vectorize(pyfunc, otypes=None, doc=None, excluded=None, cache=False,
+        signature=None)
+
+        Generalized function class.
+        Define a vectorized function which takes a nested sequence of
+        objects or numpy arrays as inputs and returns a single numpy array
+        or a tuple of numpy arrays.
+        The user-defined Python function is executed in a single cuNumeric
+        task over the set of arguments.
+        The output data type is inferred from the input arguments unless
+        it is specified explicitly via the `otypes` argument.
+        WARNING: when running with the OpenMP back-end, "vectorize" falls
+        back to the serial CPU implementation
+
+        Parameters
+        ----------
+        pyfunc : callable
+            A Python function or method.
+        otypes : str or list of dtypes, optional
+            The output data type. It must be specified as either a string of
+            typecode characters or a list of data type specifiers. There
+            should be one data type specifier for each output.
+            WARNING: cuNumeric currently requires all output types to be the
+            same
+        doc : str, optional
+            The docstring for the function. If None, the docstring will be
+            ``pyfunc.__doc__``.
+        excluded : set, optional
+            Set of strings or integers representing the positional or keyword
+            arguments for which the function will not be vectorized.
+            These will be passed directly to `pyfunc` unmodified.
+            WARNING: cuNumeric doesn't support this argument at the moment
+        cache : bool, optional
+            If `True`, cache the C function or CUDA kernel generated on the
+            first call. We recommend enabling caching in cuNumeric for
+            better performance, when possible.
+            WARNING: when cache=True, cuNumeric parses the function signature
+            and creates the C function or CUDA kernel only once, so the types
+            of the arguments passed to the vectorized function (arrays,
+            scalars, etc.) must be the same on each call.
+        signature : string, optional
+            Generalized universal function signature, e.g., ``(m,n),(n)->(m)``
+            for vectorized matrix-vector multiplication. If provided,
+            ``pyfunc`` will be called with (and expected to return)
+            arrays with shapes given by the size of corresponding core
+            dimensions. By default, ``pyfunc`` is assumed to take scalars
+            as input and output.
+            WARNING: cuNumeric doesn't support this argument at the moment
+
+        Returns
+        -------
+        vectorized : callable
+            Vectorized function.
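+
+        Examples
+        --------
+        A minimal sketch of typical usage, adapted from the integration
+        tests added with this change::
+
+            import cunumeric as num
+
+            def my_func(a, b):
+                a = a * 2 + b
+                return a
+
+            func = num.vectorize(my_func)
+            x = func(num.arange(5), num.ones((5,)))
+            # x is now [1, 3, 5, 7, 9]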
+
+        See Also
+        --------
+        numpy.vectorize
+
+        Availability
+        --------
+        Multiple GPUs, Multiple CPUs
+        """
+
+        self._pyfunc = pyfunc
+        self._otypes: Optional[tuple[Any]] = None
+        self._cache: bool = cache
+        self._numba_func: Callable[[Any], Any]
+        self._cpu_func: numba.core.ccallback.CFunc
+        self._gpu_func: tuple[Any]
+        self._args: List[Any] = []
+        self._scalar_args: List[Any] = []
+        self._scalar_idxs: List[int] = []
+        self._scalar_names: List[str] = []
+        self._arg_names: List[str] = []
+        self._context = runtime.legate_context
+        self._created: bool = False
+        self._func_body: List[str] = []
+
+        if doc is None:
+            self.__doc__ = pyfunc.__doc__
+        else:
+            self.__doc__ = doc
+
+        self._return_names = self._get_return_arguments()
+        self._num_outputs: int = len(self._return_names)
+        self._return_args: List[Any] = []
+        self._output_dtype: Optional[np.dtype[Any]] = None
+        self._cached_dtype: Optional[np.dtype[Any]] = None
+        self._cached_scalar_types: List[Any] = []
+
+        if otypes is not None:
+            if self._num_outputs != len(otypes):
+                raise ValueError(
+                    "number of types in otypes is not consistent"
+                    " with the number of return values defined in pyfunc"
+                )
+            if len(otypes) > 1:
+                for t in otypes:
+                    if t != otypes[0]:
+                        raise NotImplementedError(
+                            "cuNumeric doesn't support variable types"
+                            " in otypes"
+                        )
+            self._output_dtype = np.dtype(otypes[0])
+
+        # FIXME
+        if excluded is not None:
+            raise NotImplementedError(
+                "excluded variables are not supported yet"
+            )
+
+        # FIXME
+        if signature is not None:
+            raise NotImplementedError(
+                "signature variable is not supported yet"
+            )
+
+    def _get_func_body(self, func: Callable[[Any], Any]) -> List[str]:
+        """Return the source lines of the function body.
+
+        The length of the docstring is known from ``__doc__``, so we skip
+        that many lines, plus the signature, from the function source.
+        """
+        lines_to_skip = 0
+        if func.__doc__ is not None and len(func.__doc__.split("\n")) > 0:
+            lines_to_skip = len(func.__doc__.split("\n"))
+
+        lines = inspect.getsourcelines(func)[0]  # type: ignore
+
+        return_lines = []
+        for i in range(lines_to_skip + 1, len(lines)):
+            return_lines.append(lines[i].rstrip())
+        return return_lines
+
+    def _get_return_arguments(self) -> List[str]:
+        """
+        Return the list of names of the returned arrays/values
+        """
+        self._func_body = self._get_func_body(self._pyfunc)
+        return_names = []
+        for ln in self._func_body:
+            if "return" in ln:
+                ln = ln.replace("return", "")
+                ln = ln.replace(" ", "")
+                return_names += ln.split(",")
+        # we check if the return statement has any special characters,
+        # since we don't support cases like "return a+b"
+        regex = re.compile("[^A-Za-z0-9]")
+        for n in return_names:
+            res = regex.findall(n)
+            if len(res) > 0:
+                raise NotImplementedError(
+                    "cuNumeric doesn't support special "
+                    "characters in the return statement of the "
+                    "user-defined function"
+                )
+        return return_names
+
+    def _replace_name(
+        self, name: str, _LOOP_VAR: str, is_gpu: bool = False
+    ) -> str:
+        """
+        Add indices to the names of input/output arrays in the function body
+        """
+        if (name in self._arg_names) or (name in self._return_names):
+            return f"{name}[int({_LOOP_VAR})]"
+        else:
+            if is_gpu or name not in self._scalar_names:
+                return f"{name}"
+            else:
+                return f"{name}[0]"
+
+    def _build_gpu_function(self) -> Any:
+        funcid = f"vectorized_{self._pyfunc.__name__}"
+
+        # Preamble
+        lines = ["from numba import cuda"]
+        # we add math and numpy so user-defined functions can use them
+        lines.append("import math")
+        lines.append("import numpy")
+
+        # Signature
+        args = (
self._return_names + + self._arg_names + + self._scalar_names + + [_SIZE_VAR] + + [_DIM_VAR] + + [_PITCHES_VAR] + + [_STRIDES_VAR] + ) + + lines.append("def {}({}):".format(funcid, ",".join(args))) + # Initialize the index variable and return immediately + # when it exceeds the data size + # we compute index for sparse data access when using Legion's + # pointer. + # a[x][y][z]=a[x*strides[0] + y*strides[1] + z*strides[2]] + loop_lines = f"""\ + local_i = cuda.grid(1) + if local_i >= {_SIZE_VAR}: + return + {_LOOP_VAR}:int = 0 + for p in range({_DIM_VAR}-1): + x=int(local_i/{_PITCHES_VAR}[p]) + local_i = int(local_i%{_PITCHES_VAR}[p]) + {_LOOP_VAR}+=int(x*{_STRIDES_VAR}[p]) + {_LOOP_VAR}+=int(local_i*{_STRIDES_VAR}[{_DIM_VAR}-1]) + """ + lines += loop_lines.split("\n") + + # this function is used to replace all array names with array[i] + def _lift_to_array_access(m: Any) -> str: + return self._replace_name(m.group(0), _LOOP_VAR, True) + + # kernel body + lines_old = self._func_body + for line in lines_old: + if not ("return" in line): + l_new = re.sub(r"[_a-zA-Z]\w*", _lift_to_array_access, line) + lines.append(l_new) + + # Evaluate the string to get the Python function + body = "\n".join(lines) + glbs: Dict[str, Any] = {} + exec(body, glbs) + return glbs[funcid] + + def _build_cpu_function(self) -> Callable[[Any], Any]: + funcid = f"vectorized_{self._pyfunc.__name__}" + + # Preamble + lines = ["from numba import carray, types"] + # we add math and numpy so user-defined functions can use them + lines.append("import math") + lines.append("import numpy") + + # Signature + lines.append( + f"def {funcid}({_ARGS_VAR},{_SIZE_VAR}, " + f"{_DIM_VAR}, {_PITCHES_VAR}, {_STRIDES_VAR}):" + ) + + # Unpack kernel arguments + def _emit_assignment( + var: Any, idx: int, sz: Any, ty: np.dtype[Any] + ) -> None: + lines.append( + f" {var} = carray({ _ARGS_VAR}[{idx}], {sz}, types.{ty})" + ) + + # define pyfunc arguments as carrays + arg_idx = 0 + for count, a in enumerate(self._return_args): + type_a = a.dtype + _emit_assignment( + self._return_names[count], arg_idx, _SIZE_VAR, type_a + ) + arg_idx += 1 + for count, a in enumerate(self._args): + type_a = a.dtype + _emit_assignment( + self._arg_names[count], arg_idx, _SIZE_VAR, type_a + ) + arg_idx += 1 + for count, a in enumerate(self._scalar_args): + scalar_type = np.dtype(type(a).__name__) + _emit_assignment( + self._scalar_names[count], arg_idx, 1, scalar_type + ) + arg_idx += 1 + + # Initialize the index variable and return immediately + # when it exceeds the data size + # we compute index for sparse data access when using Legion's + # pointer. 
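+        # pitches[] holds the running products of the tile extents used to
+        # split the flat loop index, while strides[] comes from the accessor
+        # and describes the physical layout of the store, e.g.: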
+ # a[x][y][z]=a[x*strides[0] + y*strides[1] + z*strides[2]] + loop_lines = f"""\ + for local_i in range({_SIZE_VAR}): + {_LOOP_VAR}:int = 0 + j:int = local_i + for p in range({_DIM_VAR}-1): + x=int(j/{_PITCHES_VAR}[p]) + j = int(j%{_PITCHES_VAR}[p]) + {_LOOP_VAR}+=int(x*{_STRIDES_VAR}[p]) + {_LOOP_VAR}+=int(j*{_STRIDES_VAR}[{_DIM_VAR}-1]) + """ + lines += loop_lines.split("\n") + + lines_old = self._func_body + + # Kernel body + def _lift_to_array_access(m: Any) -> str: + return self._replace_name(m.group(0), _LOOP_VAR) + + for line in lines_old: + if not ("return" in line): + l_new = re.sub(r"[_a-zA-Z]\w*", _lift_to_array_access, line) + lines.append(" " + l_new) + + # Evaluate the string to get the Python function + body = "\n".join(lines) + glbs: Dict[str, Any] = {} + exec(body, glbs) + return glbs[funcid] + + def _get_numba_types(self, need_pointer: bool = True) -> list[Any]: + types = [] + for arg in self._return_args + self._args: + type_a = arg.dtype + type_a = str(type_a) if type_a != bool else "int8" + type_a = getattr(numba.core.types, type_a) + type_a = numba.core.types.CPointer(type_a) + types.append(type_a) + for arg in self._scalar_args: + type_a = np.dtype(type(arg).__name__) + type_a = str(type_a) if type_a != bool else "int8" + type_a = getattr(numba.core.types, type_a) + types.append(type_a) + return types + + def _compile_func_gpu(self) -> tuple[Any]: + types = self._get_numba_types() + arg_types = ( + types + + [numba.core.types.uint64] + + [numba.core.types.uint64] + + [numba.core.types.CPointer(numba.core.types.uint64)] + + [numba.core.types.CPointer(numba.core.types.uint64)] + ) + sig = (*arg_types,) + + cuda_arch = numba.cuda.get_current_device().compute_capability + return numba.cuda.compile_ptx(self._numba_func, sig, cc=cuda_arch) + + def _compile_func_cpu(self) -> numba.core.ccallback.CFunc: + sig = numba.core.types.void( # type: ignore + numba.types.CPointer(numba.types.voidptr), + numba.core.types.uint64, + numba.core.types.uint64, + numba.core.types.CPointer(numba.core.types.uint64), + numba.core.types.CPointer(numba.core.types.uint64), + ) + + return numba.cfunc(sig)(self._numba_func) + + def _create_cuda_kernel(self, num_gpus: int) -> None: + # create CUDA kernel + launch_domain = Rect(lo=(0,), hi=(num_gpus,)) + kernel_task = self._context.create_manual_task( + CuNumericOpCode.CREATE_CU_KERNEL, + launch_domain=launch_domain, + ) + ptx_hash = hash(self._gpu_func[0]) + kernel_task.add_scalar_arg(ptx_hash, ty.int64) + kernel_task.add_scalar_arg(self._gpu_func[0], ty.string) + kernel_task.execute() + # we want to make sure EVAL_UDF function is not executed before + # CUDA kernel is created + self._context.issue_execution_fence(block=True) + + # task has finished by the time we set self._created to True + if self._cache: + self._created = True + + @track_provenance() + def _execute(self, is_gpu: bool, num_gpus: int = 0) -> None: + if is_gpu and not self._created: + self._create_cuda_kernel(num_gpus) + + task = self._context.create_auto_task(CuNumericOpCode.EVAL_UDF) + task.add_scalar_arg(self._num_outputs, ty.uint32) # N of outputs + task.add_scalar_arg( + len(self._scalar_args), ty.uint32 + ) # N of scalar_args + + # add all scalar arguments first + for a in self._scalar_args: + dtype = convert_to_cunumeric_dtype(type(a)) + task.add_scalar_arg(a, dtype) + + num_args = len(self._args) + # add return arguments with RW permissions + first_array = None + if self._num_outputs > 0: + first_array = runtime.to_deferred_array( + self._return_args[0]._thunk + ) + 
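+            # registering the store as both input and output is what gives
+            # the task read-write access; all remaining stores are aligned
+            # to this first one below, so every array is partitioned the
+            # same way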
task.add_input(first_array.base)
+            task.add_output(first_array.base)
+
+            for i in range(1, self._num_outputs):
+                a_tmp = runtime.to_deferred_array(self._return_args[i]._thunk)
+                a_tmp_base = a_tmp.base
+                task.add_input(a_tmp_base)
+                task.add_output(a_tmp_base)
+                task.add_alignment(first_array.base, a_tmp_base)
+
+        # add array arguments with read-only permissions
+        if num_args > 0:
+            start = 0
+            if first_array is None:
+                first_array = runtime.to_deferred_array(self._args[0]._thunk)
+                task.add_input(first_array.base)
+                start = 1
+            for i in range(start, num_args):
+                a_tmp = runtime.to_deferred_array(self._args[i]._thunk)
+                a_tmp_base = a_tmp.base
+                task.add_input(a_tmp_base)
+                task.add_alignment(first_array.base, a_tmp_base)
+
+        if is_gpu:
+            ptx_hash = hash(self._gpu_func[0])
+            task.add_scalar_arg(ptx_hash, ty.int64)
+        else:
+            task.add_scalar_arg(
+                self._cpu_func.address, ty.uint64
+            )  # type: ignore
+        task.execute()
+
+    def _filter_arguments_and_check(self) -> None:
+        # this method will filter return and input arguments
+        # it will also check shape and type of the arguments
+
+        output_shape: Tuple[int] = (-1,)
+        output_dtype = self._output_dtype
+        self._return_args.clear()
+
+        # if the output type is not specified, we need to decide
+        # which one to use
+        # we also want to choose the shape for the output array
+
+        # check if an output variable is among the input arguments ->
+        # then use its dtype and shape
+        for r in self._return_names:
+            if r in self._arg_names:
+                idx = self._arg_names.index(r)
+                if output_dtype is None:
+                    output_dtype = self._args[idx].dtype
+                if output_shape == (-1,):
+                    output_shape = self._args[idx].shape
+                break
+
+        # the case where the output argument was not found among the
+        # input argnames
+        if output_shape == (-1,):
+            for r in self._return_names:
+                if r in self._scalar_names:
+                    idx = self._scalar_names.index(r)
+                    if output_dtype is None:
+                        output_dtype = np.dtype(type(self._scalar_args[idx]))
+                    output_shape = (1,)
+                    break
+
+        if self._cache and self._cached_dtype is not None:
+            if self._cached_dtype != output_dtype:
+                raise TypeError(
+                    "types of the arguments should stay the same"
+                    " for each invocation of the vectorize object"
+                )
+        elif self._cache:
+            self._cached_dtype = output_dtype
+
+        # FIXME
+        # we could find common type of input arguments here and
+        # broadcasted shapes
+        if self._num_outputs > 0 and output_dtype is None:
+            raise ValueError("Unable to choose output dtype")
+        if self._num_outputs > 0 and output_shape == (-1,):
+            raise ValueError("Unable to choose output shape")
+
+        # filling the list of return arguments
+        # check if there are return argnames in input argnames,
+        # if not, create a new array
+        for r in self._return_names:
+            if r in self._arg_names:
+                idx = self._arg_names.index(r)
+                if self._args[idx].shape != output_shape:
+                    raise ValueError(
+                        "all output arrays should have the same shape"
+                    )
+                if output_dtype != self._args[idx].dtype:
+                    runtime.warn(
+                        "converting input array to the output type in user func",
+                        category=RuntimeWarning,
+                    )
+                    self._args[idx] = self._args[idx].astype(output_dtype)
+                self._return_args.append(self._args[idx])
+                self._args.remove(self._args[idx])
+                self._arg_names.remove(r)
+            elif r in self._scalar_names:
+                idx = self._scalar_names.index(r)
+                if output_shape != (1,):
+                    raise ValueError(
+                        "all output arrays should have the same shape"
+                    )
+                self._return_args.append(
+                    full(output_shape, self._scalar_args[idx], output_dtype)
+                )
+                self._scalar_args.remove(self._scalar_args[idx])
+                self._scalar_names.remove(r)
+            else:
+                # create array and
add it to the list of return_args
+                tmp_ret = full(output_shape, 0, output_dtype)
+                self._return_args.append(tmp_ret)
+
+        # check types and shapes
+        if len(self._args) > 0:
+            for count, a in enumerate(self._args):
+                if output_dtype != a.dtype:
+                    runtime.warn(
+                        "converting input array to the output type in user func",
+                        category=RuntimeWarning,
+                    )
+                    self._args[count] = self._args[count].astype(output_dtype)
+                # FIXME broadcast shapes
+                if output_shape != self._args[count].shape:
+                    raise ValueError(
+                        "cuNumeric doesn't support "
+                        "different shapes for arrays in "
+                        "user function passed to vectorize"
+                    )
+
+    def __call__(self, *args: Any, **kwargs: Any) -> Union[Any, Tuple[Any]]:
+        # each time we call `vectorize` on a pyfunc we need to clear
+        # these lists to support different types of arguments passed
+        self._scalar_args.clear()
+        self._scalar_idxs.clear()
+        self._args.clear()
+        self._arg_names.clear()
+        self._scalar_names.clear()
+
+        scalar_idx = 0
+        for i, arg in enumerate(args):
+            if arg is None:
+                raise ValueError(
+                    "None is not supported in user function "
+                    "passed to cunumeric.vectorize"
+                )
+            elif np.ndim(arg) == 0:
+                if self._cache and not self._created:
+                    self._cached_scalar_types.append(type(arg))
+                elif self._cache:
+                    if self._cached_scalar_types[scalar_idx] != type(arg):
+                        raise TypeError(
+                            "Input arguments to vectorized function should"
+                            " have consistent types for each invocation"
+                        )
+                self._scalar_args.append(arg)
+                self._scalar_idxs.append(i)
+                scalar_idx += 1
+            else:
+                # we need to make a copy of the original array to match numpy
+                self._args.append(convert_to_cunumeric_ndarray(arg.copy()))
+
+        # first map array arguments to argument names, then scalars:
+        for i, k in enumerate(inspect.signature(self._pyfunc).parameters):
+            if i not in self._scalar_idxs:
+                self._arg_names.append(k)
+
+        for i, k in enumerate(inspect.signature(self._pyfunc).parameters):
+            if i in self._scalar_idxs:
+                self._scalar_names.append(k)
+
+        if len(kwargs) > 0:
+            raise NotImplementedError(
+                "kwargs are not supported in user functions"
+            )
+
+        self._filter_arguments_and_check()
+
+        if runtime.num_gpus > 0:
+            if not self._created:
+                self._numba_func = self._build_gpu_function()
+                self._gpu_func = self._compile_func_gpu()
+            self._execute(True, runtime.num_gpus)
+        else:
+            if not self._created:
+                self._numba_func = self._build_cpu_function()
+                self._cpu_func = self._compile_func_cpu()
+                if self._cache:
+                    self._created = True
+            self._execute(False)
+
+        if len(self._return_args) == 1:
+            return self._return_args[0]
+        if len(self._return_args) > 1:
+            return tuple(self._return_args)
+        return -1
diff --git a/cunumeric_cpp.cmake b/cunumeric_cpp.cmake
index dd8a60f7e2..990174ac03 100644
--- a/cunumeric_cpp.cmake
+++ b/cunumeric_cpp.cmake
@@ -160,6 +160,8 @@ list(APPEND cunumeric_SOURCES
   src/cunumeric/mapper.cc
   src/cunumeric/cephes/chbevl.cc
   src/cunumeric/cephes/i0.cc
+  src/cunumeric/vectorize/eval_udf.cc
+  src/cunumeric/vectorize/create_cu_kernel.cc
 )
 
 if(Legion_USE_OpenMP)
@@ -205,6 +207,8 @@
   src/cunumeric/stat/bincount_omp.cc
   src/cunumeric/convolution/convolve_omp.cc
   src/cunumeric/transform/flip_omp.cc
+  src/cunumeric/vectorize/eval_udf_omp.cc
+  src/cunumeric/vectorize/create_cu_kernel_omp.cc
 )
 endif()
 
@@ -256,6 +260,8 @@ if(Legion_USE_CUDA)
   src/cunumeric/transform/flip.cu
   src/cunumeric/arg_redop_register.cu
   src/cunumeric/cudalibs.cu
+  src/cunumeric/vectorize/eval_udf.cu
+  src/cunumeric/vectorize/create_cu_kernel.cu
 )
 endif()
diff --git a/docs/cunumeric/source/api/_vectorize.rst
b/docs/cunumeric/source/api/_vectorize.rst
new file mode 100644
index 0000000000..c096e320de
--- /dev/null
+++ b/docs/cunumeric/source/api/_vectorize.rst
@@ -0,0 +1,14 @@
+cunumeric.vectorize
+===================
+
+.. currentmodule:: cunumeric.vectorize
+
+.. autoclass:: vectorize
+
+   .. automethod:: __init__
+
+   .. rubric:: Methods
+
+   .. automethod:: __call__
+
+   .. autosummary::
diff --git a/docs/cunumeric/source/api/functional.rst b/docs/cunumeric/source/api/functional.rst
new file mode 100644
index 0000000000..4d35618ebf
--- /dev/null
+++ b/docs/cunumeric/source/api/functional.rst
@@ -0,0 +1,7 @@
+Functional programming
+======================
+
+.. toctree::
+   :maxdepth: 2
+
+   _vectorize
diff --git a/docs/cunumeric/source/api/routines.rst b/docs/cunumeric/source/api/routines.rst
index e85a5c65b0..5f0451584e 100644
--- a/docs/cunumeric/source/api/routines.rst
+++ b/docs/cunumeric/source/api/routines.rst
@@ -13,6 +13,7 @@ Routines
    logic
    math
    fft
+   functional
    random
    set
    sorting
diff --git a/pyproject.toml b/pyproject.toml
index 73ebc13c82..f577d875ce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,12 +39,14 @@ exclude = '''
     _build |
     buck-out |
     build |
-    dist
+    dist |
+    typings
   )/
 '''
 
 [tool.mypy]
 python_version = "3.10"
+mypy_path = "typings/"
 pretty = true
 show_error_codes = true
diff --git a/src/cunumeric/cuda_help.h b/src/cunumeric/cuda_help.h
index f0f0fee859..00145edc98 100644
--- a/src/cunumeric/cuda_help.h
+++ b/src/cunumeric/cuda_help.h
@@ -28,6 +28,7 @@
 #include
 #include
 #include
+#include <cuda.h>
 
 #define THREADS_PER_BLOCK 128
 #define MIN_CTAS_PER_SM 4
@@ -115,6 +116,8 @@ cublasHandle_t get_cublas();
 cusolverDnHandle_t get_cusolver();
 cutensorHandle_t* get_cutensor();
 cufftContext get_cufft_plan(cufftType type, const legate::DomainPoint& size);
+void store_udf(size_t hash, CUfunction func);
+CUfunction get_udf(size_t hash);
 
 __host__ inline void check_cublas(cublasStatus_t status, const char* file, int line)
 {
@@ -388,5 +391,4 @@ __device__ __forceinline__ void store_streaming(double* ptr, double valu
 {
   asm volatile("st.global.cs.f64 [%0], %1;" : : "l"(ptr), "d"(value) : "memory");
 }
-
 }  // namespace cunumeric
diff --git a/src/cunumeric/cudalibs.cu b/src/cunumeric/cudalibs.cu
index 7d3ab8a098..f1e3a95be1 100644
--- a/src/cunumeric/cudalibs.cu
+++ b/src/cunumeric/cudalibs.cu
@@ -233,6 +233,18 @@ cufftContext CUDALibraries::get_cufft_plan(cufftType type, const DomainPoint& si
   return cufftContext(cache->get_cufft_plan(size));
 }
 
+void CUDALibraries::store_udf_func(size_t hash, CUfunction func) { udf_caches_[hash] = func; }
+
+CUfunction CUDALibraries::get_udf_func(size_t hash)
+{
+  auto finder = udf_caches_.find(hash);
+  if (udf_caches_.end() == finder) {
+    fprintf(stderr, "UDF function wasn't generated yet\n");
+    LEGATE_ABORT;
+  }
+  return finder->second;
+}
+
 static CUDALibraries& get_cuda_libraries(legate::Processor proc)
 {
   if (proc.kind() != legate::Processor::TOC_PROC) {
@@ -278,6 +290,20 @@ cufftContext get_cufft_plan(cufftType type, const DomainPoint& size)
   return lib.get_cufft_plan(type, size);
 }
 
+void store_udf(size_t hash, CUfunction func)
+{
+  const auto proc = legate::Processor::get_executing_processor();
+  auto& lib       = get_cuda_libraries(proc);
+  lib.store_udf_func(hash, func);
+}
+
+CUfunction get_udf(size_t hash)
+{
+  const auto proc = legate::Processor::get_executing_processor();
+  auto& lib       = get_cuda_libraries(proc);
+  return lib.get_udf_func(hash);
+}
+
 class LoadCUDALibsTask : public CuNumericTask<LoadCUDALibsTask> {
  public:
   static const int TASK_ID = CUNUMERIC_LOAD_CUDALIBS;
diff
--git a/src/cunumeric/cudalibs.h b/src/cunumeric/cudalibs.h
index f2f01fffe1..8f91f3aad3 100644
--- a/src/cunumeric/cudalibs.h
+++ b/src/cunumeric/cudalibs.h
@@ -38,6 +38,8 @@ struct CUDALibraries {
   cusolverDnHandle_t get_cusolver();
   cutensorHandle_t* get_cutensor();
   cufftContext get_cufft_plan(cufftType type, const legate::DomainPoint& size);
+  void store_udf_func(size_t hash, CUfunction func);
+  CUfunction get_udf_func(size_t hash);
 
  private:
   void finalize_cublas();
@@ -50,6 +52,7 @@ struct CUDALibraries {
   cusolverDnContext* cusolver_;
   cutensorHandle_t* cutensor_;
   std::map plan_caches_;
+  std::map<size_t, CUfunction> udf_caches_;
 };
 
 }  // namespace cunumeric
diff --git a/src/cunumeric/cunumeric_c.h b/src/cunumeric/cunumeric_c.h
index c3145939ea..85ddc06593 100644
--- a/src/cunumeric/cunumeric_c.h
+++ b/src/cunumeric/cunumeric_c.h
@@ -37,11 +37,13 @@ enum CuNumericOpCode {
   CUNUMERIC_CONTRACT,
   CUNUMERIC_CONVERT,
   CUNUMERIC_CONVOLVE,
+  CUNUMERIC_CREATE_CU_KERNEL,
   CUNUMERIC_SCAN_GLOBAL,
   CUNUMERIC_SCAN_LOCAL,
   CUNUMERIC_DIAG,
   CUNUMERIC_DOT,
   CUNUMERIC_EYE,
+  CUNUMERIC_EVAL_UDF,
   CUNUMERIC_FFT,
   CUNUMERIC_FILL,
   CUNUMERIC_FLIP,
diff --git a/src/cunumeric/pitches.h b/src/cunumeric/pitches.h
index d2a63f7b32..27d179b0e5 100644
--- a/src/cunumeric/pitches.h
+++ b/src/cunumeric/pitches.h
@@ -54,6 +54,9 @@ class Pitches {
     return point;
   }
 
+  __CUDA_HD__
+  inline const size_t* data(void) { return &pitches[0]; }
+
  private:
   size_t pitches[DIM];
 };
@@ -90,6 +93,9 @@ class Pitches {
     return point;
   }
 
+  __CUDA_HD__
+  inline const size_t* data(void) { return &pitches[0]; }
+
  private:
   size_t pitches[DIM];
 };
@@ -103,8 +109,10 @@ class Pitches<0, C_ORDER> {
   {
     if (rect.lo[0] > rect.hi[0])
       return 0;
-    else
+    else {
+      pitches[0] = rect.hi[0] - rect.lo[0] + 1;
       return (rect.hi[0] - rect.lo[0] + 1);
+    }
   }
   __CUDA_HD__
   inline legate::Point<1> unflatten(size_t index, const legate::Point<1>& lo) const
@@ -113,6 +121,11 @@ class Pitches<0, C_ORDER> {
     point[0] += index;
     return point;
   }
+  __CUDA_HD__
+  inline const size_t* data(void) { return &pitches[0]; }
+
+ private:
+  size_t pitches[1];
 };
 
 }  // namespace cunumeric
diff --git a/src/cunumeric/vectorize/create_cu_kernel.cc b/src/cunumeric/vectorize/create_cu_kernel.cc
new file mode 100644
index 0000000000..effcb32c95
--- /dev/null
+++ b/src/cunumeric/vectorize/create_cu_kernel.cc
@@ -0,0 +1,33 @@
+/* Copyright 2023 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "cunumeric/vectorize/create_cu_kernel.h"
+
+namespace cunumeric {
+
+using namespace legate;
+
+/*static*/ void CreateCUKernelTask::cpu_variant(TaskContext& context) {}
+
+namespace  // unnamed
+{
+static void __attribute__((constructor)) register_tasks(void)
+{
+  CreateCUKernelTask::register_variants();
+}
+}  // namespace
+
+}  // namespace cunumeric
diff --git a/src/cunumeric/vectorize/create_cu_kernel.cu b/src/cunumeric/vectorize/create_cu_kernel.cu
new file mode 100644
index 0000000000..5805d2ef1c
--- /dev/null
+++ b/src/cunumeric/vectorize/create_cu_kernel.cu
@@ -0,0 +1,91 @@
+/* Copyright 2023 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "cunumeric/vectorize/create_cu_kernel.h"
+#include "cunumeric/cuda_help.h"
+#include <regex>
+#include <vector>
+
+namespace cunumeric {
+
+using namespace Legion;
+using namespace legate;
+
+/*static*/ void CreateCUKernelTask::gpu_variant(TaskContext& context)
+{
+  int64_t ptx_hash = context.scalars()[0].value<int64_t>();
+  std::string ptx  = context.scalars()[1].value<std::string>();
+  Processor point  = legate::Processor::get_executing_processor();
+
+  CUfunction func;
+  const unsigned num_options   = 4;
+  const size_t log_buffer_size = 16384;
+  std::vector<char> log_info_buffer(log_buffer_size);
+  std::vector<char> log_error_buffer(log_buffer_size);
+  CUjit_option jit_options[] = {
+    CU_JIT_INFO_LOG_BUFFER,
+    CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
+    CU_JIT_ERROR_LOG_BUFFER,
+    CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
+  };
+  void* option_vals[] = {
+    static_cast<void*>(log_info_buffer.data()),
+    reinterpret_cast<void*>(log_buffer_size),
+    static_cast<void*>(log_error_buffer.data()),
+    reinterpret_cast<void*>(log_buffer_size),
+  };
+
+  CUmodule module;
+  CUresult result = cuModuleLoadDataEx(&module, ptx.data(), num_options, jit_options, option_vals);
+  if (result != CUDA_SUCCESS) {
+    if (result == CUDA_ERROR_OPERATING_SYSTEM) {
+      fprintf(stderr,
+              "ERROR: Device side asserts are not supported by the "
+              "CUDA driver for MAC OSX, see NVBugs 1628896.\n");
+      exit(-1);
+    } else if (result == CUDA_ERROR_NO_BINARY_FOR_GPU) {
+      fprintf(stderr, "ERROR: The binary was compiled for the wrong GPU architecture.\n");
+      exit(-1);
+    } else {
+      fprintf(stderr, "Failed to load CUDA module! Error log: %s\n", log_error_buffer.data());
+#if CUDA_VERSION >= 6050
+      const char *name = nullptr, *str = nullptr;
+      cuGetErrorName(result, &name);
+      cuGetErrorString(result, &str);
+      fprintf(stderr, "CU: cuModuleLoadDataEx = %d (%s): %s\n", result, name, str);
+#else
+      fprintf(stderr, "CU: cuModuleLoadDataEx = %d\n", result);
+#endif
+      exit(-1);
+    }
+  }
+  std::cmatch line_match;
+  bool match =
+    std::regex_search(ptx.data(), line_match, std::regex(".visible .entry [_a-zA-Z0-9$]+"));
+#ifdef DEBUG_CUNUMERIC
+  assert(match);
+#endif
+  const auto& matched_line = line_match.begin()->str();
+  auto fun_name = matched_line.substr(matched_line.rfind(" ") + 1, matched_line.size());
+
+  result = cuModuleGetFunction(&func, module, fun_name.c_str());
+#ifdef DEBUG_CUNUMERIC
+  assert(result == CUDA_SUCCESS);
+#endif
+  store_udf(ptx_hash, func);
+}
+
+}  // namespace cunumeric
diff --git a/src/cunumeric/vectorize/create_cu_kernel.h b/src/cunumeric/vectorize/create_cu_kernel.h
new file mode 100644
index 0000000000..7b1e176756
--- /dev/null
+++ b/src/cunumeric/vectorize/create_cu_kernel.h
@@ -0,0 +1,38 @@
+/* Copyright 2023 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#pragma once
+
+#include "cunumeric/cunumeric.h"
+#include "core/data/scalar.h"
+
+namespace cunumeric {
+
+class CreateCUKernelTask : public CuNumericTask<CreateCUKernelTask> {
+ public:
+  static const int TASK_ID = CUNUMERIC_CREATE_CU_KERNEL;
+
+ public:
+  static void cpu_variant(legate::TaskContext& context);
+#ifdef LEGATE_USE_OPENMP
+  static void omp_variant(legate::TaskContext& context);
+#endif
+#ifdef LEGATE_USE_CUDA
+  static void gpu_variant(legate::TaskContext& context);
+#endif
+};
+
+}  // namespace cunumeric
diff --git a/src/cunumeric/vectorize/create_cu_kernel_omp.cc b/src/cunumeric/vectorize/create_cu_kernel_omp.cc
new file mode 100644
index 0000000000..40cc28f6c7
--- /dev/null
+++ b/src/cunumeric/vectorize/create_cu_kernel_omp.cc
@@ -0,0 +1,25 @@
+/* Copyright 2023 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "cunumeric/vectorize/create_cu_kernel.h"
+
+namespace cunumeric {
+
+using namespace legate;
+
+/*static*/ void CreateCUKernelTask::omp_variant(TaskContext& context) {}
+
+}  // namespace cunumeric
diff --git a/src/cunumeric/vectorize/eval_udf.cc b/src/cunumeric/vectorize/eval_udf.cc
new file mode 100644
index 0000000000..17e47efcc4
--- /dev/null
+++ b/src/cunumeric/vectorize/eval_udf.cc
@@ -0,0 +1,91 @@
+/* Copyright 2023 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "cunumeric/vectorize/eval_udf.h"
+#include "cunumeric/pitches.h"
+
+namespace cunumeric {
+
+using namespace legate;
+
+struct EvalUdfCPU {
+  template <Type::Code CODE, int DIM>
+  void operator()(EvalUdfArgs& args) const
+  {
+    // In the case of CPU, we pack arguments in a vector and pass them to the
+    // function (through the function pointer generated by numba)
+    using UDF = void(void**, size_t, size_t, uint64_t*, uint64_t*);
+    auto udf  = reinterpret_cast<UDF*>(args.cpu_func_ptr);
+    std::vector<void*> udf_args;
+    size_t volume = 1;
+    Pitches<DIM - 1> pitches;
+    Rect<DIM> rect;
+    size_t strides[DIM];
+    if (args.inputs.size() > 0) {
+      using VAL = legate_type_of<CODE>;
+      rect      = args.inputs[0].shape<DIM>();
+      volume    = pitches.flatten(rect);
+
+      if (rect.empty()) return;
+      for (size_t i = 0; i < args.inputs.size(); i++) {
+        if (i < args.num_outputs) {
+          auto out = args.outputs[i].write_accessor<VAL, DIM>(rect);
+          udf_args.push_back(reinterpret_cast<void*>(out.ptr(rect, strides)));
+        } else {
+          auto in = args.inputs[i].read_accessor<VAL, DIM>(rect);
+          udf_args.push_back(reinterpret_cast<void*>(const_cast<VAL*>(in.ptr(rect, strides))));
+        }
+      }
+    }  // if
+    for (auto s : args.scalars) udf_args.push_back(const_cast<void*>(s.ptr()));
+    udf(udf_args.data(),
+        volume,
+        size_t(DIM),
+        reinterpret_cast<uint64_t*>(const_cast<size_t*>(pitches.data())),
+        reinterpret_cast<uint64_t*>(&strides[0]));
+  }
+};
+
+/*static*/ void EvalUdfTask::cpu_variant(TaskContext& context)
+{
+  uint32_t num_outputs = context.scalars()[0].value<uint32_t>();
+  uint32_t num_scalars = context.scalars()[1].value<uint32_t>();
+  std::vector<Scalar> scalars;
+  for (size_t i = 2; i < (2 + num_scalars); i++) scalars.push_back(context.scalars()[i]);
+
+  EvalUdfArgs args{context.scalars()[2 + num_scalars].value<uint64_t>(),
+                   context.inputs(),
+                   context.outputs(),
+                   scalars,
+                   num_outputs,
+                   legate::Processor::get_executing_processor()};
+  int dim = 1;
+  if (args.inputs.size() > 0) {
+    dim = args.inputs[0].dim() == 0 ? 1 : args.inputs[0].dim();
+    assert(dim > 0);
+    double_dispatch(dim, args.inputs[0].code(), EvalUdfCPU{}, args);
+  } else {
+    Type::Code code = Type::Code::BOOL;
+    double_dispatch(dim, code, EvalUdfCPU{}, args);
+  }
+}
+
+namespace  // unnamed
+{
+static void __attribute__((constructor)) register_tasks(void) { EvalUdfTask::register_variants(); }
+}  // namespace
+
+}  // namespace cunumeric
diff --git a/src/cunumeric/vectorize/eval_udf.cu b/src/cunumeric/vectorize/eval_udf.cu
new file mode 100644
index 0000000000..74264a69d6
--- /dev/null
+++ b/src/cunumeric/vectorize/eval_udf.cu
@@ -0,0 +1,143 @@
+/* Copyright 2023 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "cunumeric/vectorize/eval_udf.h"
+#include "cunumeric/cuda_help.h"
+#include "cunumeric/pitches.h"
+#include <cstring>
+#include <vector>
+
+namespace cunumeric {
+
+using namespace legate;
+
+struct EvalUdfGPU {
+  template <Type::Code CODE, int DIM>
+  void operator()(EvalUdfArgs& args) const
+  {
+    using VAL = legate_type_of<CODE>;
+    Rect<DIM> rect;
+
+    size_t input_size = args.inputs.size();
+    CUfunction func   = get_udf(args.hash);
+
+    // Filling up the buffer with arguments
+    size_t buffer_size = (input_size + args.scalars.size()) * sizeof(void*);
+    buffer_size += sizeof(size_t);  // size
+    buffer_size += sizeof(size_t);  // dim
+    buffer_size += sizeof(void*);   // pitches
+    buffer_size += sizeof(void*);   // strides
+
+    std::vector<char> arg_buffer(buffer_size);
+    char* raw_arg_buffer = arg_buffer.data();
+
+    auto p = raw_arg_buffer;
+    size_t strides[DIM];
+    size_t size = 1;
+    if (input_size > 0) {
+      rect = args.inputs[0].shape<DIM>();
+      size = rect.volume();
+      for (size_t i = 0; i < input_size; i++) {
+        if (i < args.num_outputs) {
+          auto out                           = args.outputs[i].write_accessor<VAL, DIM>(rect);
+          *reinterpret_cast<const void**>(p) = out.ptr(rect, strides);
+        } else {
+          auto in                            = args.inputs[i].read_accessor<VAL, DIM>(rect);
+          *reinterpret_cast<const void**>(p) = in.ptr(rect, strides);
+        }
+        p += sizeof(void*);
+      }
+    }
+    for (auto scalar : args.scalars) {
+      memcpy(p, scalar.ptr(), scalar.size());
+      p += scalar.size();
+    }
+    memcpy(p, &size, sizeof(size_t));
+    size_t dim = DIM;
+    p += sizeof(size_t);
+    memcpy(p, &dim, sizeof(size_t));
+    p += sizeof(size_t);
+    Pitches<DIM - 1> pitches;
+    size_t volume = pitches.flatten(rect);
+    // create buffers for pitches and strides since
+    // we need to pass pointer to device memory
+    auto device_pitches = create_buffer<size_t>(Point<1>(DIM - 1), Memory::Kind::Z_COPY_MEM);
+    auto device_strides = create_buffer<size_t>(Point<1>(DIM), Memory::Kind::Z_COPY_MEM);
+    for (size_t i = 0; i < DIM; i++) {
+      if (i != DIM - 1) { device_pitches[Point<1>(i)] = pitches.data()[i]; }
+      device_strides[Point<1>(i)] = strides[i];
+    }
+    *reinterpret_cast<size_t**>(p) = device_pitches.ptr(Point<1>(0));
+    p += sizeof(void*);
+    *reinterpret_cast<size_t**>(p) = device_strides.ptr(Point<1>(0));
+    p += sizeof(void*);
+
+    void* config[] = {
+      CU_LAUNCH_PARAM_BUFFER_POINTER,
+      static_cast<void*>(raw_arg_buffer),
+      CU_LAUNCH_PARAM_BUFFER_SIZE,
+      &buffer_size,
+      CU_LAUNCH_PARAM_END,
+    };
+
+    const uint32_t gridDimX = (size + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
+    const uint32_t gridDimY = 1;
+    const uint32_t gridDimZ = 1;
+
+    const uint32_t blockDimX = THREADS_PER_BLOCK;
+    const uint32_t blockDimY = 1;
+    const uint32_t blockDimZ = 1;
+
+    auto stream = get_cached_stream();
+
+    CUresult status = cuLaunchKernel(
+      func, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, 0, stream, NULL, config);
+    if (status != CUDA_SUCCESS) {
+      fprintf(stderr, "Failed to launch a CUDA kernel\n");
+      assert(false);
+    }
+
+    CHECK_CUDA_STREAM(stream);
+  }
+};
+
+/*static*/ void EvalUdfTask::gpu_variant(TaskContext& context)
+{
+  uint32_t num_outputs = context.scalars()[0].value<uint32_t>();
+  uint32_t num_scalars = context.scalars()[1].value<uint32_t>();
+  std::vector<Scalar> scalars;
+  for (size_t i = 2; i < (2 + num_scalars); i++) scalars.push_back(context.scalars()[i]);
+
+  int64_t ptx_hash = context.scalars()[2 + num_scalars].value<int64_t>();
+
+  EvalUdfArgs args{0,
+                   context.inputs(),
+                   context.outputs(),
+                   scalars,
+                   num_outputs,
+                   legate::Processor::get_executing_processor(),
+                   ptx_hash};
+  size_t dim = 1;
+  if (args.inputs.size() > 0) {
+    dim = args.inputs[0].dim() == 0 ? 1 : args.inputs[0].dim();
+    double_dispatch(dim, args.inputs[0].code(), EvalUdfGPU{}, args);
+  } else {
+    Type::Code code = Type::Code::BOOL;
+    double_dispatch(dim, code, EvalUdfGPU{}, args);
+  }
+}
+}  // namespace cunumeric
diff --git a/src/cunumeric/vectorize/eval_udf.h b/src/cunumeric/vectorize/eval_udf.h
new file mode 100644
index 0000000000..784e2334b1
--- /dev/null
+++ b/src/cunumeric/vectorize/eval_udf.h
@@ -0,0 +1,48 @@
+/* Copyright 2023 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#pragma once
+
+#include "cunumeric/cunumeric.h"
+#include "core/data/scalar.h"
+
+namespace cunumeric {
+
+struct EvalUdfArgs {
+  uint64_t cpu_func_ptr;
+  std::vector<legate::Store>& inputs;
+  std::vector<legate::Store>& outputs;
+  std::vector<legate::Scalar> scalars;
+  uint32_t num_outputs;
+  Legion::Processor point;
+  int64_t hash = 0;
+};
+
+class EvalUdfTask : public CuNumericTask<EvalUdfTask> {
+ public:
+  static const int TASK_ID = CUNUMERIC_EVAL_UDF;
+
+ public:
+  static void cpu_variant(legate::TaskContext& context);
+#ifdef LEGATE_USE_OPENMP
+  static void omp_variant(legate::TaskContext& context);
+#endif
+#ifdef LEGATE_USE_CUDA
+  static void gpu_variant(legate::TaskContext& context);
+#endif
+};
+
+}  // namespace cunumeric
diff --git a/src/cunumeric/vectorize/eval_udf_omp.cc b/src/cunumeric/vectorize/eval_udf_omp.cc
new file mode 100644
index 0000000000..c6e2991733
--- /dev/null
+++ b/src/cunumeric/vectorize/eval_udf_omp.cc
@@ -0,0 +1,29 @@
+/* Copyright 2023 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/vectorize/eval_udf.h" + +namespace cunumeric { + +using namespace Legion; +using namespace legate; + +/*static*/ void EvalUdfTask::omp_variant(TaskContext& context) +{ + EvalUdfTask::cpu_variant(context); +} + +} // namespace cunumeric diff --git a/tests/integration/test_vectorize.py b/tests/integration/test_vectorize.py new file mode 100644 index 0000000000..df03987fd7 --- /dev/null +++ b/tests/integration/test_vectorize.py @@ -0,0 +1,220 @@ +# Copyright 2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import numpy as np +import pytest +from legate.core import LEGATE_MAX_DIM +from utils.generators import mk_seq_array + +import cunumeric as num + + +def my_func(a, b): + a = a * 2 + b + return a + + +# Capital letters and numbers in the signature +def my_func2(A0, B0): + A0 = A0 * 2 + B0 + C0 = A0 * 2 + return A0, C0 + + +def test_vectorize(): + # 2 arrays + func = num.vectorize(my_func) + a = num.arange(5) + b = num.ones((5,)) + a = func(a, b) + assert np.array_equal(a, [1, 3, 5, 7, 9]) + + # array and scalar + func = num.vectorize(my_func) + a = num.arange(5) + b = 2 + a = func(a, b) + assert np.array_equal(a, [2, 4, 6, 8, 10]) + + # 2 scalars + func = num.vectorize(my_func) + a = 3 + b = 2 + a = func(a, b) + assert a == 8 + + +def empty_func(): + print("within empty function") + + +def test_empty_functions(): + # empty function + func = num.vectorize(empty_func) + func() + + +func_num = num.vectorize(my_func) +func_np = np.vectorize(my_func) + + +@pytest.mark.parametrize( + "slice", + ( + (Ellipsis), + ( + slice(5, 10), + 2, + ), + (slice(3, 7),), + ( + Ellipsis, + 2, + ), + ), +) +def test_vectorize_over_slices(slice): + a = np.arange(160).reshape((10, 4, 4)) + a_num = num.array(a) + b = a * 10 + b_num = num.array(b) + a[slice] = func_np(a[slice], b[slice]) + a_num[slice] = func_num(a_num[slice], b_num[slice]) + assert np.array_equal(a, a_num) + + +def test_multiple_outputs(): + # checking signature with capital letters and numbers + # + checking multiple outputs + a = np.arange(100).reshape((25, 4)) + a_num = num.array(a) + b = a * 10 + b_num = a_num * 10 + func_np = np.vectorize(my_func2) + func_num = num.vectorize(my_func2) + a, c = func_np(a, b) + a_num, c_num = func_num(a_num, b_num) + assert np.array_equal(a, a_num) + assert np.array_equal(c, c_num) + + +def test_different_types(): + # checking the case when input and output types are different + a = np.arange(100, dtype=int).reshape((25, 4)) + a_num = num.array(a) + b = a * 10 + b_num = a_num * 10 + func_np = np.vectorize(my_func, otypes=(float,)) + 
func_num = num.vectorize(my_func, otypes=(float,)) + a = func_np(a, b) + a_num = func_num(a_num, b_num) + assert np.array_equal(a, a_num) + + # another test for different types + a = np.arange(100, dtype=float).reshape((25, 4)) + a_num = num.array(a) + b = a * 10 + b_num = a_num * 10 + func_np = np.vectorize( + my_func2, + otypes=( + int, + int, + ), + ) + func_num = num.vectorize( + my_func2, + otypes=( + int, + int, + ), + ) + a, c = func_np(a, b) + a_num, c_num = func_num(a_num, b_num) + assert np.array_equal(a, a_num) + assert np.array_equal(c, c_num) + + +def test_cache_multiple_outputs(): + a = np.arange(100).reshape((25, 4)) + a_num = num.array(a) + b = a * 10 + b_num = a_num * 10 + func_np = np.vectorize(my_func2, cache=True) + func_num = num.vectorize(my_func2, cache=True) + for i in range(10): + a = a * 2 + b = b * 3 + a_num = a_num * 2 + b_num = b_num * 3 + a, c = func_np(a, b) + a_num, c_num = func_num(a_num, b_num) + assert np.array_equal(a, a_num) + assert np.array_equal(c, c_num) + + a_num = a_num.astype(float) + b_num = b_num.astype(float) + msg = r"types of the arguments should stay the same" + with pytest.raises(TypeError, match=msg): + a_num = func_num(a_num, b_num) + + +def test_cache_single_output(): + a = np.arange(100).reshape((2, 50)) + a_num = num.array(a) + b = a * 10 + b_num = a_num * 10 + func_np = np.vectorize(my_func, cache=True) + func_num = num.vectorize(my_func, cache=True) + for i in range(10): + a = a * 2 + b = b * 3 + a_num = a_num * 2 + b_num = b_num * 3 + a = func_np(a, b) + a_num = func_num(a_num, b_num) + assert np.array_equal(a, a_num) + + a_num = a_num.astype(float) + b_num = b_num.astype(float) + msg = r"types of the arguments should stay the same" + with pytest.raises(TypeError, match=msg): + a_num = func_num(a_num, b_num) + + +# checking caching on different shapes of arrays: +func_np2 = np.vectorize(my_func2, cache=True) +func_num2 = num.vectorize(my_func2, cache=True) + + +@pytest.mark.parametrize("ndim", range(1, LEGATE_MAX_DIM + 1)) +def test_nd_vectorize(ndim): + a_shape = tuple(np.random.randint(1, 9) for _ in range(ndim)) + a = mk_seq_array(np, a_shape) + a_num = num.array(a) + b = a * 2 + b_num = num.array(b) + a, c = func_np2(a, b) + a_num, c_num = func_num2(a_num, b_num) + assert np.array_equal(a, a_num) + assert np.array_equal(c, c_num) + + +if __name__ == "__main__": + import sys + + np.random.seed(12345) + sys.exit(pytest.main(sys.argv)) diff --git a/typings/numba/__init__.pyi b/typings/numba/__init__.pyi new file mode 100644 index 0000000000..3aa25ebbd1 --- /dev/null +++ b/typings/numba/__init__.pyi @@ -0,0 +1,10 @@ +from typing import Any, Callable + +import numba.core.types as types +import numba.cuda # import compile_ptx +from numba.core import types +from numba.core.ccallback import CFunc +from numba.core.types import CPointer, uint64 + +def cfunc(sig: Any) -> Any: + def wrapper(func: Callable[[Any], Any]) -> tuple[Any]: ... diff --git a/typings/numba/core/__init__.pyi b/typings/numba/core/__init__.pyi new file mode 100644 index 0000000000..e69de29bb2 diff --git a/typings/numba/core/ccallback/__init__.pyi b/typings/numba/core/ccallback/__init__.pyi new file mode 100644 index 0000000000..81b5030b9c --- /dev/null +++ b/typings/numba/core/ccallback/__init__.pyi @@ -0,0 +1,8 @@ +from typing import Any + +class CFunc(object): + def __init__( + self, pyfunc: Any, sig: Any, locals: Any, options: Any + ) -> None: ... + @property + def address(self) -> int: ... 
diff --git a/typings/numba/core/types/__init__.pyi b/typings/numba/core/types/__init__.pyi new file mode 100644 index 0000000000..8bb1e2b103 --- /dev/null +++ b/typings/numba/core/types/__init__.pyi @@ -0,0 +1,25 @@ +class Opaque: ... + +class NoneType(Opaque): + def __init__(self, name: str) -> None: ... + +class Type: + def __init__(self, name: str) -> None: ... + +class Number(Type): ... + +class Integer(Number): + def __init__(self, name: str) -> None: ... + +class RawPointer: + def __init__(self, name: str) -> None: ... + +class CPointer(Type): + def __init__(self, dtype: Type) -> None: ... + +none = NoneType("none") + +uint32 = Integer("uint32") +uint64 = Integer("uint64") +void = none +voidptr = Type("void*") diff --git a/typings/numba/cuda/__init__.pyi b/typings/numba/cuda/__init__.pyi new file mode 100644 index 0000000000..d66e40c5f4 --- /dev/null +++ b/typings/numba/cuda/__init__.pyi @@ -0,0 +1,5 @@ +from typing import Any + +from numba.cuda.compiler import compile_ptx as compile_ptx + +def get_current_device() -> Any: ... diff --git a/typings/numba/cuda/compiler.pyi b/typings/numba/cuda/compiler.pyi new file mode 100644 index 0000000000..56e02dd3e2 --- /dev/null +++ b/typings/numba/cuda/compiler.pyi @@ -0,0 +1,12 @@ +from typing import Any, Callable, Optional + +def compile_ptx( + pyfunc: Callable[[Any], Any], + args: Any, + debug: bool = False, + lineinfo: bool = False, + device: bool = False, + fastmath: bool = False, + cc: Optional[Any] = None, + opt: bool = True, +) -> tuple[Any]: ... diff --git a/typings/numba/types/CPointer.pyi b/typings/numba/types/CPointer.pyi new file mode 100644 index 0000000000..249a23f191 --- /dev/null +++ b/typings/numba/types/CPointer.pyi @@ -0,0 +1,5 @@ +# import numpy as np +from numba.core.types.abstract import Type + +class CPointer(Type): + def __init__(self, dtype: Type) -> None: ... diff --git a/typings/numba/types/__init__.pyi b/typings/numba/types/__init__.pyi new file mode 100644 index 0000000000..14c90eca2a --- /dev/null +++ b/typings/numba/types/__init__.pyi @@ -0,0 +1,12 @@ +class Type: ... +class Number(Type): ... + +class Integer(Number): + def __init__(self, name: str) -> None: ... + +class CPointer(Type): + def __init__(self, dtype: Type) -> None: ... + +uint32 = Integer("uint32") +uint64 = Integer("uint64") +void = None diff --git a/typings/pyarrow/__init__.pyi b/typings/pyarrow/__init__.pyi new file mode 100644 index 0000000000..cc2ac93aa9 --- /dev/null +++ b/typings/pyarrow/__init__.pyi @@ -0,0 +1,136 @@ +# Copyright 2021-2022 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, Union + +from .lib import ( + DataType, + binary, + bool_, + float16, + float32, + float64, + int8, + int16, + int32, + int64, + string, + uint8, + uint16, + uint32, + uint64, +) + +class Field: + name: str + type: DataType + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def with_name(self, name: str) -> Field: ... 
+ +def field( + name: Union[str, bytes], + type: DataType, + nullable: bool = True, + metadata: Any = None, +) -> Field: ... + +class Schema: + types: Any + def field(self, i: Union[str, int]) -> Field: ... + def get_all_field_indices(self, name: str) -> list[int]: ... + def get_field_index(self, name: str) -> int: ... + def __len__(self) -> int: ... + def __getitem__(self, idx: int) -> Field: ... + +def schema(fields: Any, metadata: Any = None) -> Schema: ... + +class ExtensionType: + def __init__(self, dtype: DataType, name: str) -> None: ... + +class DictionaryType: ... +class ListType: ... +class MapType: ... +class StructType: ... +class UnionType: ... +class TimestampType: ... +class Time32Type: ... +class Time64Type: ... +class FixedSizeBinaryType: ... +class Decimal128Type: ... +class time32: ... +class time64: ... +class timestamp: ... +class date32: ... +class date64: ... +class large_binary: ... +class large_string: ... +class large_utf8: ... +class decimal128: ... +class large_list: ... +class struct: ... +class dictionary: ... +class null: ... +class utf8: ... +class list_: ... +class map_: ... + +def from_numpy_dtype(dtype: Any) -> DataType: ... + +__all__ = ( + "binary", + "bool_", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float16", + "float32", + "float64", + "Field", + "Schema", + "DataType", + "DictionaryType", + "ListType", + "MapType", + "StructType", + "UnionType", + "TimestampType", + "Time32Type", + "Time64Type", + "FixedSizeBinaryType", + "Decimal128Type", + "time32", + "time64", + "timestamp", + "date32", + "date64", + "string", + "large_binary", + "large_string", + "large_utf8", + "decimal128", + "large_list", + "struct", + "dictionary", + "null", + "utf8", + "list_", + "map_", + "from_numpy_dtype", +) diff --git a/typings/pyarrow/lib.pyi b/typings/pyarrow/lib.pyi new file mode 100644 index 0000000000..398361089b --- /dev/null +++ b/typings/pyarrow/lib.pyi @@ -0,0 +1,38 @@ +# Copyright 2021-2022 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any + +class DataType: + id: int + num_fields: int + num_buffers: int + def equals(self, other: object) -> bool: ... + def to_pandas_dtype(self) -> Any: ... + +def binary(length: int) -> DataType: ... +def bool_() -> DataType: ... +def int8() -> DataType: ... +def int16() -> DataType: ... +def int32() -> DataType: ... +def int64() -> DataType: ... +def uint8() -> DataType: ... +def uint16() -> DataType: ... +def uint32() -> DataType: ... +def uint64() -> DataType: ... +def float16() -> DataType: ... +def float32() -> DataType: ... +def float64() -> DataType: ... +def string() -> DataType: ...