5555from typing import Any
5656from typing import Callable
5757from typing import Dict
58+ from typing import Iterable
59+ from typing import Iterator
5860from typing import List
5961from typing import Optional
6062from typing import Set
@@ -288,11 +290,21 @@ def identity(x: Any) -> Any:
288290 return x
289291
290292
def chunked(seq: Sequence[Any], max_chunks: int) -> Iterator[Sequence[Any]]:
    """
    Yield up to max_chunks contiguous chunks from seq, split as evenly as possible.

    Parameters
    ----------
    seq : Sequence[Any]
        The sequence to split; it must support ``len()`` and slicing
    max_chunks : int
        The maximum number of chunks to yield. A value that is zero,
        negative, or larger than ``len(seq)`` yields one chunk per item.

    Yields
    ------
    Sequence[Any]
        Non-empty slices of ``seq`` in their original order. Chunk sizes
        differ by at most one, with the larger chunks first. Nothing is
        yielded when ``seq`` is empty.

    """
    n = len(seq)
    if n == 0:
        # Nothing to split; this also keeps the division below from
        # dividing by zero, since max_chunks would be clamped to n == 0
        return

    if max_chunks <= 0 or max_chunks > n:
        max_chunks = n

    # Every chunk gets `base` items and the first `extra` chunks get one
    # more, so all max_chunks slots are used and sizes stay balanced
    base, extra = divmod(n, max_chunks)

    start = 0
    for i in range(max_chunks):
        stop = start + base + (1 if i < extra else 0)
        yield seq[start:stop]
        start = stop
301+
302+
async def run_in_parallel(
    func: Callable[..., Any],
    params_list: Sequence[Sequence[Any]],
    cancel_event: threading.Event,
    limit: Optional[int] = None,
    transformer: Callable[[Any], Any] = identity,
) -> List[Any]:
    """
    Call a function once per parameter set, spread over concurrent batches.

    The parameter sets are split into at most ``limit`` batches. Each
    batch is processed sequentially, and the batches are awaited together.

    Parameters
    ----------
    func : Callable[..., Any]
        The function to call; it may be a coroutine function or a
        regular function
    params_list : Sequence[Sequence[Any]]
        The positional arguments for each call, one entry per call
    cancel_event : threading.Event
        The event to check for cancellation
    limit : Optional[int]
        The maximum number of concurrent tasks to run. If None, the
        ``external_function.concurrency_limit`` option is read at call time.
    transformer : Callable[[Any], Any]
        A function to transform the results

    Returns
    -------
    List[Any]
        The results of the function calls, in the order of ``params_list``

    """
    if limit is None:
        # Looked up per call: a default argument would be evaluated only
        # once, when the function is defined, and miss later changes
        limit = get_option('external_function.concurrency_limit')

    if len(params_list) == 0:
        # No calls to make; avoids creating batches from an empty sequence
        return []

    is_async = asyncio.iscoroutinefunction(func)

    async def call(batch: Sequence[Any]) -> List[Any]:
        """Loop over batches of parameters and call the function."""
        res = []
        for params in batch:
            # Checked before every call so a batch can stop part-way
            cancel_on_event(cancel_event)
            if is_async:
                res.append(transformer(await func(*params)))
            else:
                res.append(transformer(func(*params)))
        return res

    async def thread_call(batch: Sequence[Any]) -> List[Any]:
        """Run a batch on the current loop if async, else through to_thread."""
        if is_async:
            return await call(batch)
        # Regular functions would block the current event loop, so the
        # whole batch is handed to to_thread and driven by its own loop
        return await to_thread(lambda: asyncio.run(call(batch)))

    # Create tasks in chunks to limit concurrency
    tasks = [thread_call(batch) for batch in chunked(params_list, limit)]

    results = await asyncio.gather(*tasks)

    # Each task returns one list per batch; flatten them back into one
    return list(itertools.chain.from_iterable(results))
331356
332357
333358def build_udf_endpoint (
0 commit comments