 
 logger = utils.get_logger('singlestoredb.functions.ext.asgi')
 
-# If a number of processes is specified, create a pool of workers
-num_processes = max(0, int(os.environ.get('SINGLESTOREDB_EXT_NUM_PROCESSES', 0)))
-if num_processes > 1:
-    try:
-        from ray.util.multiprocessing import Pool
-    except ImportError:
-        from multiprocessing import Pool
-    func_map = Pool(num_processes).starmap
-else:
-    func_map = itertools.starmap
-
-
 async def to_thread(
     func: Any, /, *args: Any, **kwargs: Dict[str, Any],
 ) -> Any:
@@ -295,6 +283,53 @@ def cancel_on_event(
 )
 
 
+def identity(x: Any) -> Any:
+    """Identity function."""
+    return x
+
+
+async def run_in_parallel(
+    func: Callable[..., Any],
+    params_list: Sequence[Sequence[Any]],
+    cancel_event: threading.Event,
+    limit: int = 10,
+    transformer: Callable[[Any], Any] = identity,
+) -> List[Any]:
298+ """"
299+ Run a function in parallel with a limit on the number of concurrent tasks.
300+
301+ Parameters
302+ ----------
303+ func : Callable
304+ The function to call in parallel
305+ params_list : Sequence[Sequence[Any]]
306+ The parameters to pass to the function
307+ cancel_event : threading.Event
308+ The event to check for cancellation
309+ limit : int
310+ The maximum number of concurrent tasks to run
311+
312+ Returns
313+ -------
314+ List[Any]
315+ The results of the function calls
316+
317+ """
+    semaphore = asyncio.Semaphore(limit)
+
+    async def worker(params: Sequence[Any]) -> Any:
+        async with semaphore:
+            cancel_on_event(cancel_event)
+            if asyncio.iscoroutinefunction(func):
+                return transformer(await func(*params))
+            else:
+                return transformer(await to_thread(func, *params))
+
+    tasks = [worker(p) for p in params_list]
+
+    return await asyncio.gather(*tasks)
+
+
 def build_udf_endpoint(
     func: Callable[..., Any],
     returns_data_format: str,
@@ -317,23 +352,15 @@ def build_udf_endpoint(
317352 """
318353 if returns_data_format in ['scalar' , 'list' ]:
319354
320- is_async = asyncio .iscoroutinefunction (func )
321-
322355 async def do_func (
323356 cancel_event : threading .Event ,
324357 timer : Timer ,
325358 row_ids : Sequence [int ],
326359 rows : Sequence [Sequence [Any ]],
327360 ) -> Tuple [Sequence [int ], List [Tuple [Any , ...]]]:
328361 '''Call function on given rows of data.'''
329- out = []
330362 async with timer ('call_function' ):
331- for row in rows :
332- cancel_on_event (cancel_event )
333- if is_async :
334- out .append (await func (* row ))
335- else :
336- out .append (func (* row ))
363+ out = await run_in_parallel (func , rows , cancel_event )
337364 return row_ids , list (zip (out ))
338365
339366 return do_func
@@ -428,28 +455,20 @@ def build_tvf_endpoint(
428455 """
429456 if returns_data_format in ['scalar' , 'list' ]:
430457
431- is_async = asyncio .iscoroutinefunction (func )
432-
433458 async def do_func (
434459 cancel_event : threading .Event ,
435460 timer : Timer ,
436461 row_ids : Sequence [int ],
437462 rows : Sequence [Sequence [Any ]],
438463 ) -> Tuple [Sequence [int ], List [Tuple [Any , ...]]]:
439464 '''Call function on given rows of data.'''
440- out_ids : List [int ] = []
441- out = []
442- # Call function on each row of data
443465 async with timer ('call_function' ):
444- for i , row in zip (row_ids , rows ):
445- cancel_on_event (cancel_event )
446- if is_async :
447- res = await func (* row )
448- else :
449- res = func (* row )
450- out .extend (as_list_of_tuples (res ))
451- out_ids .extend ([row_ids [i ]] * (len (out )- len (out_ids )))
452- return out_ids , out
466+ items = await run_in_parallel (
467+ func , rows , cancel_event ,
468+ transformer = as_list_of_tuples ,
469+ )
470+ out = list (itertools .chain .from_iterable (items ))
471+ return [row_ids [0 ]] * len (out ), out
453472
454473 return do_func
455474
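For reference, a minimal usage sketch of the new run_in_parallel helper (not part of this commit). It assumes a singlestoredb build that includes this change; the slow_add function, the parameter values, and the limit of 5 are illustrative only. Blocking callables are sent to worker threads via to_thread, coroutine functions are awaited directly, and the semaphore caps how many calls are in flight at once.

# Hypothetical driver for run_in_parallel; names below are illustrative, not from the commit.
import asyncio
import threading
import time

from singlestoredb.functions.ext.asgi import run_in_parallel


def slow_add(x: int, y: int) -> int:
    """Blocking stand-in for a UDF body; run_in_parallel dispatches it to a thread."""
    time.sleep(0.1)
    return x + y


async def main() -> None:
    # The event is checked by each worker; setting it aborts outstanding work.
    cancel_event = threading.Event()
    params = [(i, i + 1) for i in range(20)]
    # At most 5 calls execute concurrently; asyncio.gather keeps results in input order.
    results = await run_in_parallel(slow_add, params, cancel_event, limit=5)
    print(results)


asyncio.run(main())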