66from dataclasses import dataclass
77from functools import partial , wraps
88from inspect import signature
9- from typing import ParamSpec , TypeVar
9+ from itertools import islice
10+ from typing import (
11+ ParamSpec ,
12+ TypeVar ,
13+ AsyncIterator ,
14+ AsyncIterable ,
15+ Awaitable ,
16+ Generator ,
17+ Iterable ,
18+ )
1019
1120import nest_asyncio
21+ import torch
1222from icij_common .logging_utils import (
1323 DATE_FMT ,
1424 STREAM_HANDLER_FMT ,
2333from temporalio .common import SearchAttributeKey
2434from temporalio .exceptions import ApplicationError
2535
36+ from .constants import CPU , MPS , MKL , CUDA
37+ from .objects import Predicate
2638from .types_ import ProgressRateHandler , RawProgressHandler
2739
2840DependencyLabel = str | None
@@ -201,6 +213,7 @@ async def wrapper(*args, **kwargs) -> T:
201213 # recreate kwargs from pargs
202214 new_args , new_kwargs = _unpack_positional_args (args , keyword_only , params )
203215 return await activity_fn (* new_args , ** new_kwargs , ** kwargs )
216+
204217 else :
205218
206219 @wraps (activity_fn )
@@ -211,9 +224,11 @@ def wrapper(*args, **kwargs) -> T:
211224
212225 # Update the decorated function signature to appear as p-args only
213226 new_params = [
214- p .replace (kind = inspect .Parameter .POSITIONAL_OR_KEYWORD )
215- if p .kind == inspect .Parameter .KEYWORD_ONLY
216- else p
227+ (
228+ p .replace (kind = inspect .Parameter .POSITIONAL_OR_KEYWORD )
229+ if p .kind == inspect .Parameter .KEYWORD_ONLY
230+ else p
231+ )
217232 for p in params
218233 ]
219234 wrapper .__signature__ = sig .replace (parameters = new_params )
@@ -252,6 +267,7 @@ async def wrapper(*args, **kwargs) -> T:
252267 raise
253268 except Exception as e :
254269 raise fatal_error_from_exception (e ) from e
270+
255271 else :
256272
257273 @wraps (activity_fn )
@@ -374,3 +390,83 @@ def _handlers(
374390 handler .addFilter (worker_id_filter )
375391 handler .setLevel (log_level )
376392 return handlers
393+
394+
395+ # Temporal utils
396+ async def async_batches (
397+ iterable : AsyncIterable [T ], batch_size : int
398+ ) -> AsyncIterator [tuple [T ]]:
399+ it = aiter (iterable )
400+ if batch_size < 1 :
401+ raise ValueError ("n must be at least one" )
402+ while True :
403+ batch = []
404+ while len (batch ) < batch_size :
405+ try :
406+ batch .append (await anext (it ))
407+ except StopAsyncIteration :
408+ if batch :
409+ yield tuple (batch )
410+ return
411+ yield tuple (batch )
412+
413+
414+ def batches (
415+ iterable : Iterable [T ], batch_size : int
416+ ) -> Generator [tuple [T , ...], None , None ]:
417+ if batch_size < 1 :
418+ raise ValueError ("n must be at least one" )
419+ it = iter (iterable )
420+ while batch := tuple (islice (it , batch_size )):
421+ yield batch
422+
423+
424+ async def maybe_await (maybe_awaitable : Awaitable [T ] | T ) -> T :
425+ if inspect .isawaitable (maybe_awaitable ):
426+ return await maybe_awaitable
427+ return maybe_awaitable
428+
429+
430+ async def once (item : T ) -> AsyncIterator [T ]:
431+ yield item
432+
433+
434+ def before_and_after (
435+ iterable : AsyncIterable [T ], predicate : Predicate [T ]
436+ ) -> tuple [AsyncIterable [T ], AsyncIterable [T ]]:
437+ transition = asyncio .get_event_loop ().create_future ()
438+
439+ async def true_iterator () -> AsyncIterator [T ]:
440+ async for elem in iterable :
441+ if await maybe_await (predicate (elem )):
442+ yield elem
443+ else :
444+ transition .set_result (elem )
445+ return
446+ transition .set_exception (StopAsyncIteration )
447+
448+ async def remainder_iterator () -> AsyncIterator [T ]:
449+ try :
450+ yield await transition
451+ except StopAsyncIteration :
452+ return
453+ async for elm in iterable :
454+ yield elm
455+
456+ return true_iterator (), remainder_iterator ()
457+
458+
459+ # Torch utils
460+ def find_device (device_name : str = CPU ) -> torch .Device :
461+ """Find a device by name if available
462+
463+ :param device_name: Device name
464+ :return: torch.Device
465+ """
466+ if (
467+ hasattr (torch .backends , device_name )
468+ and getattr (torch .backends , device_name ).is_available ()
469+ ):
470+ return torch .device (device_name )
471+
472+ return torch .device (CPU )
0 commit comments