Skip to content

Commit ae877bd

Browse files
committed
Initial implementation of parallel UDFs using async
1 parent 377ddcf commit ae877bd

File tree

1 file changed

+54
-35
lines changed

1 file changed

+54
-35
lines changed

singlestoredb/functions/ext/asgi.py

Lines changed: 54 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -92,18 +92,6 @@
9292

9393
logger = utils.get_logger('singlestoredb.functions.ext.asgi')
9494

95-
# If a number of processes is specified, create a pool of workers
96-
num_processes = max(0, int(os.environ.get('SINGLESTOREDB_EXT_NUM_PROCESSES', 0)))
97-
if num_processes > 1:
98-
try:
99-
from ray.util.multiprocessing import Pool
100-
except ImportError:
101-
from multiprocessing import Pool
102-
func_map = Pool(num_processes).starmap
103-
else:
104-
func_map = itertools.starmap
105-
106-
10795
async def to_thread(
10896
func: Any, /, *args: Any, **kwargs: Dict[str, Any],
10997
) -> Any:
@@ -295,6 +283,53 @@ def cancel_on_event(
295283
)
296284

297285

286+
def identity(x: Any) -> Any:
    """
    Return the argument unchanged.

    Parameters
    ----------
    x : Any
        The value to pass through

    Returns
    -------
    Any
        The same object that was passed in

    """
    return x
289+
290+
291+
async def run_in_parallel(
    func: Callable[..., Any],
    params_list: Sequence[Sequence[Any]],
    cancel_event: threading.Event,
    limit: int = 10,
    transformer: Callable[[Any], Any] = identity,
) -> List[Any]:
    """
    Run a function over many parameter sets with bounded concurrency.

    Each entry of ``params_list`` is unpacked as the positional arguments
    of one call to ``func``. Coroutine functions are awaited directly;
    any other callable is dispatched through ``to_thread``. Before each
    call, ``cancel_on_event`` is invoked with ``cancel_event``.

    Parameters
    ----------
    func : Callable
        The function to call in parallel
    params_list : Sequence[Sequence[Any]]
        The parameters to pass to the function; one inner sequence per call
    cancel_event : threading.Event
        The event to check for cancellation
    limit : int
        The maximum number of concurrent tasks to run (must be >= 1)
    transformer : Callable[[Any], Any]
        Function applied to each individual result before it is returned

    Returns
    -------
    List[Any]
        The transformed results of the function calls, in the same order
        as ``params_list``

    Raises
    ------
    ValueError
        If ``limit`` is less than 1

    """
    # A semaphore created with 0 permits would block every worker forever.
    if limit < 1:
        raise ValueError('limit must be a positive integer')

    semaphore = asyncio.Semaphore(limit)

    # The kind of callable cannot change between rows; inspect it once.
    is_async = asyncio.iscoroutinefunction(func)

    async def worker(params: Sequence[Any]) -> Any:
        async with semaphore:
            cancel_on_event(cancel_event)
            if is_async:
                result = await func(*params)
            else:
                result = await to_thread(func, *params)
            return transformer(result)

    tasks = [asyncio.create_task(worker(p)) for p in params_list]

    try:
        return list(await asyncio.gather(*tasks))
    except BaseException:
        # gather() propagates the first failure but leaves the remaining
        # awaitables running; cancel them so no work outlives this call.
        # Cancelling a task that has already finished is a no-op.
        for task in tasks:
            task.cancel()
        raise
331+
332+
298333
def build_udf_endpoint(
299334
func: Callable[..., Any],
300335
returns_data_format: str,
@@ -317,23 +352,15 @@ def build_udf_endpoint(
317352
"""
318353
if returns_data_format in ['scalar', 'list']:
319354

320-
is_async = asyncio.iscoroutinefunction(func)
321-
322355
async def do_func(
    cancel_event: threading.Event,
    timer: Timer,
    row_ids: Sequence[int],
    rows: Sequence[Sequence[Any]],
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
    '''
    Call the wrapped function once per input row.

    The calls are delegated to ``run_in_parallel`` inside the
    ``call_function`` timer. ``row_ids`` is returned untouched, and
    each result is wrapped in a one-element tuple.
    '''
    async with timer('call_function'):
        results = await run_in_parallel(func, rows, cancel_event)
    # One single-column tuple per result, in input order.
    return row_ids, [(value,) for value in results]
338365

339366
return do_func
@@ -428,28 +455,20 @@ def build_tvf_endpoint(
428455
"""
429456
if returns_data_format in ['scalar', 'list']:
430457

431-
is_async = asyncio.iscoroutinefunction(func)
432-
433458
async def do_func(
    cancel_event: threading.Event,
    timer: Timer,
    row_ids: Sequence[int],
    rows: Sequence[Sequence[Any]],
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
    '''
    Call the wrapped function once per input row and flatten the results.

    The calls are delegated to ``run_in_parallel`` inside the
    ``call_function`` timer, with ``as_list_of_tuples`` applied to
    each individual result. The per-row results are concatenated in
    input order, and every output tuple is paired with the id of the
    input row that produced it.
    '''
    async with timer('call_function'):
        items = await run_in_parallel(
            func, rows, cancel_event,
            transformer=as_list_of_tuples,
        )

    out_ids: List[int] = []
    out: List[Tuple[Any, ...]] = []
    # Tag each output tuple with its own source row id. Using only the
    # first id for the whole batch mislabels rows as soon as the batch
    # holds more than one input row, and fails on an empty batch.
    for row_id, item in zip(row_ids, items):
        out.extend(item)
        # The length difference is the number of tuples this row added.
        out_ids.extend([row_id] * (len(out) - len(out_ids)))
    return out_ids, out
453472

454473
return do_func
455474

0 commit comments

Comments
 (0)