
Commit a3e9d9c

Only use a thread for a batch of rows, not one for each row

1 parent ae877bd · commit a3e9d9c

2 files changed: +41 additions, -9 deletions

singlestoredb/config.py  (7 additions, 0 deletions)

@@ -468,6 +468,13 @@
     environ=['SINGLESTOREDB_EXT_FUNC_APP_NAME'],
 )
 
+register_option(
+    'external_function.concurrency_limit', 'int', check_int, 1,
+    'Specifies the maximum number of subsets of a batch of rows '
+    'to process simultaneously.',
+    environ=['SINGLESTOREDB_EXT_FUNC_CONCURRENCY_LIMIT'],
+)
+
 #
 # Debugging options
 #
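The option defaults to 1, so a batch of rows is processed by a single worker unless the limit is raised. A minimal sketch of configuring it through the registered environment variable; the import path for get_option and the point at which the environ= binding is read are assumptions, not shown in this diff:

    import os

    # Raise the limit before the external function app reads its options
    # (variable name registered via environ= above).
    os.environ['SINGLESTOREDB_EXT_FUNC_CONCURRENCY_LIMIT'] = '8'

    from singlestoredb.config import get_option  # assumed import path

    # Expected to reflect the environment variable through the environ= binding.
    print(get_option('external_function.concurrency_limit'))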

singlestoredb/functions/ext/asgi.py  (34 additions, 9 deletions)

@@ -55,6 +55,8 @@
 from typing import Any
 from typing import Callable
 from typing import Dict
+from typing import Iterable
+from typing import Iterator
 from typing import List
 from typing import Optional
 from typing import Set
@@ -288,11 +290,21 @@ def identity(x: Any) -> Any:
     return x
 
 
+def chunked(seq: Sequence[Any], max_chunks: int) -> Iterator[Sequence[Any]]:
+    """Yield up to max_chunks chunks from seq, splitting as evenly as possible."""
+    n = len(seq)
+    if max_chunks <= 0 or max_chunks > n:
+        max_chunks = n
+    chunk_size = (n + max_chunks - 1) // max_chunks  # ceil division
+    for i in range(0, n, chunk_size):
+        yield seq[i:i + chunk_size]
+
+
 async def run_in_parallel(
     func: Callable[..., Any],
     params_list: Sequence[Sequence[Any]],
     cancel_event: threading.Event,
-    limit: int = 10,
+    limit: int = get_option('external_function.concurrency_limit'),
     transformer: Callable[[Any], Any] = identity,
 ) -> List[Any]:
     """"
@@ -308,26 +320,39 @@ async def run_in_parallel(
         The event to check for cancellation
     limit : int
         The maximum number of concurrent tasks to run
+    transformer : Callable[[Any], Any]
+        A function to transform the results
 
     Returns
     -------
     List[Any]
         The results of the function calls
 
     """
-    semaphore = asyncio.Semaphore(limit)
+    is_async = asyncio.iscoroutinefunction(func)
 
-    async def worker(params: Sequence[Any]) -> Any:
-        async with semaphore:
+    async def call(batch: Sequence[Any]) -> Any:
+        """Loop over batches of parameters and call the function."""
+        res = []
+        for params in batch:
             cancel_on_event(cancel_event)
-            if asyncio.iscoroutinefunction(func):
-                return transformer(await func(*params))
+            if is_async:
+                res.append(transformer(await func(*params)))
             else:
-                return transformer(await to_thread(func, *params))
+                res.append(transformer(func(*params)))
+        return res
+
+    async def thread_call(batch: Sequence[Any]) -> Any:
+        if is_async:
+            return await call(batch)
+        return await to_thread(lambda: asyncio.run(call(batch)))
+
+    # Create tasks in chunks to limit concurrency
+    tasks = [thread_call(batch) for batch in chunked(params_list, limit)]
 
-    tasks = [worker(p) for p in params_list]
+    results = await asyncio.gather(*tasks)
 
-    return await asyncio.gather(*tasks)
+    return list(itertools.chain.from_iterable(results))
 
 
 def build_udf_endpoint(
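Taken together, the change replaces the per-row scheme (one to_thread hop per row, bounded by a semaphore of 10) with a per-chunk scheme: the incoming batch is split into at most limit chunks, each chunk's rows run sequentially inside a single worker thread (or directly on the event loop for async functions), and the per-chunk result lists are flattened back into row order with itertools.chain. The sketch below illustrates that pattern outside the ASGI app; it is a simplified stand-in, not the library code: it drops cancel_event, transformer and the async path, and slow_square / run_batched are hypothetical names.

    import asyncio
    import itertools
    import time


    def chunked(seq, max_chunks):
        """Split seq into at most max_chunks near-equal contiguous slices."""
        n = len(seq)
        if max_chunks <= 0 or max_chunks > n:
            max_chunks = n
        chunk_size = (n + max_chunks - 1) // max_chunks
        for i in range(0, n, chunk_size):
            yield seq[i:i + chunk_size]


    def slow_square(x):          # hypothetical synchronous row function
        time.sleep(0.01)         # stand-in for real per-row work
        return x * x


    async def run_batched(func, params_list, limit=4):
        def call(batch):
            # One worker thread walks a whole chunk of rows sequentially.
            return [func(*params) for params in batch]

        # asyncio.to_thread needs Python 3.9+; asgi.py uses its own to_thread shim.
        tasks = [asyncio.to_thread(call, batch) for batch in chunked(params_list, limit)]
        results = await asyncio.gather(*tasks)

        # gather() preserves task order, so flattening restores row order.
        return list(itertools.chain.from_iterable(results))


    print(asyncio.run(run_batched(slow_square, [(i,) for i in range(10)])))
    # -> [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]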
