44
55import functools
66import logging
7- import sys
87import warnings
8+ from collections .abc import Callable
99from collections .abc import Iterable
1010from collections .abc import Mapping
1111from functools import partial
1212from typing import Any
13- from typing import Callable
1413from typing import cast
14+ from typing import ParamSpec
1515from typing import TypeVar
1616
17- if sys .version_info >= (3 , 10 ): # pragma: >3.10 cover
18- from typing import ParamSpec
19- else : # pragma: <3.10 cover
20- from typing_extensions import ParamSpec
21-
2217try :
18+ from dask ._task_spec import DataNode
2319 from dask .base import tokenize
2420 from dask .utils import funcname
2521 from distributed import Client as DaskDistributedClient
4036from proxystore .store .utils import get_key
4137from proxystore .warnings import ExperimentalWarning
4238
43- try : # pragma: >3.9 cover
44- from dask ._task_spec import DataNode
45-
46- class _ProxyNode (DataNode ):
47- key : ConnectorKeyT
48- value : Proxy [Any ]
49-
50- USE_TASK_SPEC = True
51- except ImportError : # pragma: <=3.9 cover
52- USE_TASK_SPEC = False
53-
5439warnings .warn (
5540 'Dask plugins are an experimental feature and may exhibit unexpected '
5641 'behaviour or change in the future.' ,
@@ -65,6 +50,11 @@ class _ProxyNode(DataNode):
6550logger = logging .getLogger (__name__ )
6651
6752
53+ class _ProxyNode (DataNode ):
54+ key : ConnectorKeyT
55+ value : Proxy [Any ]
56+
57+
6858class Client (DaskDistributedClient ):
6959 """Dask Distributed Client with ProxyStore support.
7060
@@ -211,11 +201,11 @@ def map( # type: ignore[no-untyped-def]
211201 # and instead want to wait to proxy until the later calls to map()
212202 # on each batch.
213203 key = key or funcname (func )
214- iterables = list (zip (* zip (* iterables ) )) # type: ignore[assignment]
204+ iterables = list (zip (* zip (* iterables , strict = False ), strict = False )) # type: ignore[assignment]
215205 if not isinstance (key , list ) and pure : # pragma: no branch
216206 key = [
217207 f'{ key } -{ tokenize (func , kwargs , * args )} -proxy'
218- for args in zip (* iterables )
208+ for args in zip (* iterables , strict = False )
219209 ]
220210
221211 iterables = tuple (
@@ -265,7 +255,7 @@ def map( # type: ignore[no-untyped-def]
265255 not (batch_size and batch_size > 1 and total_length > batch_size )
266256 and self ._ps_store is not None
267257 ):
268- for future , * args in zip (futures , * iterables ):
258+ for future , * args in zip (futures , * iterables , strict = False ):
269259 # TODO: how to delete kwargs?
270260 callback = partial (
271261 _evict_proxies_callback ,
@@ -385,7 +375,7 @@ def submit( # type: ignore[no-untyped-def]
385375
386376
387377def _evict_proxies_callback (
388- _future : DaskDistributedFuture ,
378+ _future : DaskDistributedFuture [ Any ] ,
389379 keys : Iterable [ConnectorKeyT ],
390380 store : Store [Any ],
391381) -> None :
@@ -394,17 +384,11 @@ def _evict_proxies_callback(
394384
395385
396386def _get_keys (iterable : Iterable [Any ]) -> tuple [ConnectorKeyT , ...]:
397- if USE_TASK_SPEC : # pragma: >3.9 cover
398- return tuple (x .key for x in iterable if isinstance (x , _ProxyNode ))
399- else : # pragma: <=3.9 cover
400- return tuple (x for x in iterable if isinstance (x , Proxy ))
387+ return tuple (x .key for x in iterable if isinstance (x , _ProxyNode ))
401388
402389
403390def _is_proxy (obj : Any ) -> bool :
404- if USE_TASK_SPEC : # pragma: >3.9 cover
405- return isinstance (obj , (_ProxyNode , Proxy ))
406- else : # pragma: <=3.9 cover
407- return isinstance (obj , Proxy )
391+ return isinstance (obj , _ProxyNode | Proxy )
408392
409393
410394def _proxy_by_size (
@@ -495,13 +479,10 @@ def _proxy_iterable(
495479 for value in iterable
496480 )
497481
498- if USE_TASK_SPEC : # pragma: >3.9 cover
499- return tuple (
500- _ProxyNode (get_key (obj ), obj ) if isinstance (obj , Proxy ) else obj
501- for obj in objects
502- )
503- else : # pragma: <=3.9 cover
504- return objects
482+ return tuple (
483+ _ProxyNode (get_key (obj ), obj ) if isinstance (obj , Proxy ) else obj
484+ for obj in objects
485+ )
505486
506487
507488def _proxy_mapping (
@@ -533,15 +514,10 @@ def _proxy_mapping(
533514 for key in mapping
534515 }
535516
536- if USE_TASK_SPEC : # pragma: >3.9 cover
537- return {
538- key : _ProxyNode (get_key (obj ), obj )
539- if isinstance (obj , Proxy )
540- else obj
541- for key , obj in objects .items ()
542- }
543- else : # pragma: <=3.9 cover
544- return objects
517+ return {
518+ key : _ProxyNode (get_key (obj ), obj ) if isinstance (obj , Proxy ) else obj
519+ for key , obj in objects .items ()
520+ }
545521
546522
547523def _proxy_task_wrapper (
0 commit comments