Skip to content

Commit 2625f9b

Browse files
committed
Add concurrency limit option
1 parent a3e9d9c commit 2625f9b

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

singlestoredb/functions/decorator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def _func(
104104
args: Optional[ParameterType] = None,
105105
returns: Optional[ReturnType] = None,
106106
timeout: Optional[int] = None,
107+
concurrency_limit: Optional[int] = None,
107108
) -> UDFType:
108109
"""Generic wrapper for UDF and TVF decorators."""
109110

@@ -113,6 +114,7 @@ def _func(
113114
args=expand_types(args),
114115
returns=expand_types(returns),
115116
timeout=timeout,
117+
concurrency_limit=concurrency_limit,
116118
).items() if v is not None
117119
}
118120

@@ -156,6 +158,7 @@ def udf(
156158
args: Optional[ParameterType] = None,
157159
returns: Optional[ReturnType] = None,
158160
timeout: Optional[int] = None,
161+
concurrency_limit: Optional[int] = None,
159162
) -> UDFType:
160163
"""
161164
Define a user-defined function (UDF).
@@ -186,6 +189,10 @@ def udf(
186189
timeout : int, optional
187190
The timeout in seconds for the UDF execution. If not specified,
188191
the global default timeout is used.
192+
concurrency_limit : int, optional
193+
The maximum number of concurrent subsets of rows that will be
194+
processed simultaneously by the UDF. If not specified,
195+
the global default concurrency limit is used.
189196
190197
Returns
191198
-------
@@ -198,4 +205,5 @@ def udf(
198205
args=args,
199206
returns=returns,
200207
timeout=timeout,
208+
concurrency_limit=concurrency_limit,
201209
)

singlestoredb/functions/ext/asgi.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,6 @@ async def run_in_parallel(
304304
func: Callable[..., Any],
305305
params_list: Sequence[Sequence[Any]],
306306
cancel_event: threading.Event,
307-
limit: int = get_option('external_function.concurrency_limit'),
308307
transformer: Callable[[Any], Any] = identity,
309308
) -> List[Any]:
310309
""""
@@ -318,8 +317,6 @@ async def run_in_parallel(
318317
The parameters to pass to the function
319318
cancel_event : threading.Event
320319
The event to check for cancellation
321-
limit : int
322-
The maximum number of concurrent tasks to run
323320
transformer : Callable[[Any], Any]
324321
A function to transform the results
325322
@@ -329,6 +326,7 @@ async def run_in_parallel(
329326
The results of the function calls
330327
331328
"""
329+
limit = get_concurrency_limit(func)
332330
is_async = asyncio.iscoroutinefunction(func)
333331

334332
async def call(batch: Sequence[Any]) -> Any:
@@ -355,6 +353,29 @@ async def thread_call(batch: Sequence[Any]) -> Any:
355353
return list(itertools.chain.from_iterable(results))
356354

357355

356+
def get_concurrency_limit(func: Callable[..., Any]) -> int:
357+
"""
358+
Get the concurrency limit for a function.
359+
360+
Parameters
361+
----------
362+
func : Callable
363+
The function to get the concurrency limit for
364+
365+
Returns
366+
-------
367+
int
368+
The concurrency limit for the function
369+
370+
"""
371+
return max(
372+
1, func._singlestoredb_attrs.get( # type: ignore
373+
'concurrency_limit',
374+
get_option('external_function.concurrency_limit'),
375+
),
376+
)
377+
378+
358379
def build_udf_endpoint(
359380
func: Callable[..., Any],
360381
returns_data_format: str,

0 commit comments

Comments
 (0)