diff --git a/singlestoredb/config.py b/singlestoredb/config.py
index 71c75173..5c0a721b 100644
--- a/singlestoredb/config.py
+++ b/singlestoredb/config.py
@@ -468,6 +468,13 @@
     environ=['SINGLESTOREDB_EXT_FUNC_APP_NAME'],
 )
 
+register_option(
+    'external_function.concurrency_limit', 'int', check_int, 1,
+    'Specifies the maximum number of subsets of a batch of rows '
+    'to process simultaneously.',
+    environ=['SINGLESTOREDB_EXT_FUNC_CONCURRENCY_LIMIT'],
+)
+
 #
 # Debugging options
 #
diff --git a/singlestoredb/functions/decorator.py b/singlestoredb/functions/decorator.py
index 3da98ff4..f9666f01 100644
--- a/singlestoredb/functions/decorator.py
+++ b/singlestoredb/functions/decorator.py
@@ -104,6 +104,7 @@ def _func(
     args: Optional[ParameterType] = None,
     returns: Optional[ReturnType] = None,
     timeout: Optional[int] = None,
+    concurrency_limit: Optional[int] = None,
 ) -> UDFType:
     """Generic wrapper for UDF and TVF decorators."""
@@ -113,6 +114,7 @@
             args=expand_types(args),
             returns=expand_types(returns),
             timeout=timeout,
+            concurrency_limit=concurrency_limit,
         ).items() if v is not None
     }
@@ -156,6 +158,7 @@ def udf(
     args: Optional[ParameterType] = None,
     returns: Optional[ReturnType] = None,
     timeout: Optional[int] = None,
+    concurrency_limit: Optional[int] = None,
 ) -> UDFType:
     """
     Define a user-defined function (UDF).
@@ -186,6 +189,10 @@
     timeout : int, optional
         The timeout in seconds for the UDF execution. If not specified,
         the global default timeout is used.
+    concurrency_limit : int, optional
+        The maximum number of concurrent subsets of rows that will be
+        processed simultaneously by the UDF. If not specified,
+        the global default concurrency limit is used.
 
     Returns
     -------
@@ -198,4 +205,5 @@
         args=args,
         returns=returns,
         timeout=timeout,
+        concurrency_limit=concurrency_limit,
     )
diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py
index d9090b38..27f50f95 100755
--- a/singlestoredb/functions/ext/asgi.py
+++ b/singlestoredb/functions/ext/asgi.py
@@ -50,6 +50,7 @@
 import zipimport
 from collections.abc import Awaitable
 from collections.abc import Iterable
+from collections.abc import Iterator
 from collections.abc import Sequence
 from types import ModuleType
 from typing import Any
@@ -92,17 +93,6 @@
 logger = utils.get_logger('singlestoredb.functions.ext.asgi')
 
 
-# If a number of processes is specified, create a pool of workers
-num_processes = max(0, int(os.environ.get('SINGLESTOREDB_EXT_NUM_PROCESSES', 0)))
-if num_processes > 1:
-    try:
-        from ray.util.multiprocessing import Pool
-    except ImportError:
-        from multiprocessing import Pool
-    func_map = Pool(num_processes).starmap
-else:
-    func_map = itertools.starmap
-
 async def to_thread(
     func: Any, /, *args: Any, **kwargs: Dict[str, Any],
@@ -295,6 +285,102 @@ def cancel_on_event(
         )
 
+
+def identity(x: Any) -> Any:
+    """Identity function."""
+    return x
+
+
+def chunked(seq: Sequence[Any], max_chunks: int) -> Iterator[Sequence[Any]]:
+    """Yield up to max_chunks chunks from seq, splitting as evenly as possible."""
+    n = len(seq)
+    if max_chunks <= 0 or max_chunks > n:
+        max_chunks = n
+    chunk_size = (n + max_chunks - 1) // max_chunks  # ceil division
+    for i in range(0, n, chunk_size):
+        yield seq[i:i + chunk_size]
+
+
+async def run_in_parallel(
+    func: Callable[..., Any],
+    params_list: Sequence[Sequence[Any]],
+    cancel_event: threading.Event,
+    transformer: Callable[[Any], Any] = identity,
+) -> List[Any]:
+    """
+    Run a function in parallel with a limit on the number of concurrent tasks.
+ + Parameters + ---------- + func : Callable + The function to call in parallel + params_list : Sequence[Sequence[Any]] + The parameters to pass to the function + cancel_event : threading.Event + The event to check for cancellation + transformer : Callable[[Any], Any] + A function to transform the results + + Returns + ------- + List[Any] + The results of the function calls + + """ + limit = get_concurrency_limit(func) + is_async = asyncio.iscoroutinefunction(func) + + async def call_sync(batch: Sequence[Any]) -> Any: + """Loop over batches of parameters and call the sync function.""" + res = [] + for params in batch: + cancel_on_event(cancel_event) + res.append(transformer(func(*params))) + return res + + async def call_async(batch: Sequence[Any]) -> Any: + """Loop over batches of parameters and call the async function.""" + res = [] + for params in batch: + cancel_on_event(cancel_event) + res.append(transformer(await func(*params))) + return res + + async def thread_call(batch: Sequence[Any]) -> Any: + if is_async: + return await call_async(batch) + return await to_thread(lambda: asyncio.run(call_sync(batch))) + + # Create tasks in chunks to limit concurrency + tasks = [thread_call(batch) for batch in chunked(params_list, limit)] + + results = await asyncio.gather(*tasks) + + return list(itertools.chain.from_iterable(results)) + + +def get_concurrency_limit(func: Callable[..., Any]) -> int: + """ + Get the concurrency limit for a function. + + Parameters + ---------- + func : Callable + The function to get the concurrency limit for + + Returns + ------- + int + The concurrency limit for the function + + """ + return max( + 1, func._singlestoredb_attrs.get( # type: ignore + 'concurrency_limit', + get_option('external_function.concurrency_limit'), + ), + ) + + def build_udf_endpoint( func: Callable[..., Any], returns_data_format: str, @@ -317,8 +403,6 @@ def build_udf_endpoint( """ if returns_data_format in ['scalar', 'list']: - is_async = asyncio.iscoroutinefunction(func) - async def do_func( cancel_event: threading.Event, timer: Timer, @@ -326,14 +410,8 @@ async def do_func( rows: Sequence[Sequence[Any]], ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]: '''Call function on given rows of data.''' - out = [] async with timer('call_function'): - for row in rows: - cancel_on_event(cancel_event) - if is_async: - out.append(await func(*row)) - else: - out.append(func(*row)) + out = await run_in_parallel(func, rows, cancel_event) return row_ids, list(zip(out)) return do_func @@ -428,8 +506,6 @@ def build_tvf_endpoint( """ if returns_data_format in ['scalar', 'list']: - is_async = asyncio.iscoroutinefunction(func) - async def do_func( cancel_event: threading.Event, timer: Timer, @@ -437,19 +513,13 @@ async def do_func( rows: Sequence[Sequence[Any]], ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]: '''Call function on given rows of data.''' - out_ids: List[int] = [] - out = [] - # Call function on each row of data async with timer('call_function'): - for i, row in zip(row_ids, rows): - cancel_on_event(cancel_event) - if is_async: - res = await func(*row) - else: - res = func(*row) - out.extend(as_list_of_tuples(res)) - out_ids.extend([row_ids[i]] * (len(out)-len(out_ids))) - return out_ids, out + items = await run_in_parallel( + func, rows, cancel_event, + transformer=as_list_of_tuples, + ) + out = list(itertools.chain.from_iterable(items)) + return [row_ids[0]] * len(out), out return do_func
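
Example usage (a minimal sketch, not part of the diff: it assumes the package's
public import path, from singlestoredb.functions import udf, and slow_upper is a
hypothetical UDF):

    import asyncio

    from singlestoredb.functions import udf

    # Per-function override of the concurrency limit. Without it, the
    # handler falls back to the 'external_function.concurrency_limit'
    # option (default 1), which can also be set through the
    # SINGLESTOREDB_EXT_FUNC_CONCURRENCY_LIMIT environment variable.
    @udf(concurrency_limit=8)
    async def slow_upper(text: str) -> str:
        # Simulated I/O-bound work; run_in_parallel() splits each batch
        # of rows into at most 8 subsets via chunked() and awaits them
        # together with asyncio.gather().
        await asyncio.sleep(0.01)
        return text.upper()

For a batch of 100 rows and a limit of 8, chunked() yields subsets of
ceil(100 / 8) = 13 rows (the last one holds the remaining 9), and
run_in_parallel() creates one task per subset: async functions stay on the
event loop, while sync functions are dispatched to worker threads via
to_thread().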