1919 create_pool ,
2020 decode_hosts ,
2121)
22- from typing import TYPE_CHECKING , Dict , List , Tuple , Union , Optional
2322
24- if TYPE_CHECKING :
23+ import typing
24+
25+ if typing .TYPE_CHECKING :
2526 from redis .asyncio .connection import ConnectionPool
2627 from redis .asyncio .client import Redis
27- from .core import RedisChannelLayer
2828 from typing_extensions import Buffer
2929
3030logger = logging .getLogger (__name__ )
@@ -39,10 +39,10 @@ class ChannelLock:
3939 """
4040
4141 def __init__ (self ):
42- self .locks : " collections.defaultdict[str, asyncio.Lock]" = (
42+ self .locks : collections .defaultdict [str , asyncio .Lock ] = (
4343 collections .defaultdict (asyncio .Lock )
4444 )
45- self .wait_counts : " collections.defaultdict[str, int]" = collections .defaultdict (
45+ self .wait_counts : collections .defaultdict [str , int ] = collections .defaultdict (
4646 int
4747 )
4848
@@ -87,7 +87,7 @@ class RedisLoopLayer:
8787 def __init__ (self , channel_layer : "RedisChannelLayer" ):
8888 self ._lock = asyncio .Lock ()
8989 self .channel_layer = channel_layer
90- self ._connections : " Dict[int, Redis]" = {}
90+ self ._connections : typing . Dict [int , " Redis" ] = {}
9191
9292 def get_connection (self , index : int ) -> "Redis" :
9393 if index not in self ._connections :
@@ -145,7 +145,7 @@ def __init__(
145145 symmetric_encryption_keys = symmetric_encryption_keys ,
146146 )
147147 # Cached redis connection pools and the event loop they are from
148- self ._layers : " Dict[asyncio.AbstractEventLoop, RedisLoopLayer]" = {}
148+ self ._layers : typing . Dict [asyncio .AbstractEventLoop , " RedisLoopLayer" ] = {}
149149 # Normal channels choose a host index by cycling through the available hosts
150150 self ._receive_index_generator = itertools .cycle (range (len (self .hosts )))
151151 self ._send_index_generator = itertools .cycle (range (len (self .hosts )))
@@ -154,15 +154,15 @@ def __init__(
154154 # Number of coroutines trying to receive right now
155155 self .receive_count = 0
156156 # The receive lock
157- self .receive_lock : " Optional[asyncio.Lock]" = None
157+ self .receive_lock : typing . Optional [asyncio .Lock ] = None
158158 # Event loop they are trying to receive on
159- self .receive_event_loop : " Optional[asyncio.AbstractEventLoop]" = None
159+ self .receive_event_loop : typing . Optional [asyncio .AbstractEventLoop ] = None
160160 # Buffered messages by process-local channel name
161- self .receive_buffer : " collections.defaultdict[str, BoundedQueue]" = (
161+ self .receive_buffer : collections .defaultdict [str , BoundedQueue ] = (
162162 collections .defaultdict (functools .partial (BoundedQueue , self .capacity ))
163163 )
164164 # Detached channel cleanup tasks
165- self .receive_cleaners : " List[asyncio.Task]" = []
165+ self .receive_cleaners : typing . List [asyncio .Task ] = []
166166 # Per-channel cleanup locks to prevent a receive starting and moving
167167 # a message back into the main queue before its cleanup has completed
168168 self .receive_clean_locks = ChannelLock ()
@@ -180,7 +180,7 @@ async def send(self, channel: str, message):
180180 """
181181 # Typecheck
182182 assert isinstance (message , dict ), "message is not a dict"
183- assert self .require_valid_channel_name (channel ), "Channel name not valid"
183+ assert self .valid_channel_name (channel ), "Channel name not valid"
184184 # Make sure the message does not contain reserved keys
185185 assert "__asgi_channel__" not in message
186186 # If it's a process-local channel, strip off local part and stick full name in message
@@ -221,7 +221,7 @@ def _backup_channel_name(self, channel: str) -> str:
221221 return channel + "$inflight"
222222
223223 async def _brpop_with_clean (
224- self , index : int , channel : str , timeout : " Union[int, float, bytes, str]"
224+ self , index : int , channel : str , timeout : typing . Union [int , float , bytes , str ]
225225 ):
226226 """
227227 Perform a Redis BRPOP and manage the backup processing queue.
@@ -269,7 +269,7 @@ async def receive(self, channel: str):
269269 """
270270 # Make sure the channel name is valid then get the non-local part
271271 # and thus its index
272- assert self .require_valid_channel_name (channel )
272+ assert self .valid_channel_name (channel )
273273 if "!" in channel :
274274 real_channel = self .non_local_name (channel )
275275 assert real_channel .endswith (
@@ -385,14 +385,12 @@ async def receive(self, channel: str):
385385 # Do a plain direct receive
386386 return (await self .receive_single (channel ))[1 ]
387387
388- async def receive_single (self , channel : str ) -> " Tuple" :
388+ async def receive_single (self , channel : str ) -> typing . Tuple :
389389 """
390390 Receives a single message off of the channel and returns it.
391391 """
392392 # Check channel name
393- assert self .require_valid_channel_name (
394- channel , receive = True
395- ), "Channel name invalid"
393+ assert self .valid_channel_name (channel , receive = True ), "Channel name invalid"
396394 # Work out the connection to use
397395 if "!" in channel :
398396 assert channel .endswith ("!" )
@@ -423,7 +421,7 @@ async def receive_single(self, channel: str) -> "Tuple":
423421 )
424422 self .receive_cleaners .append (cleaner )
425423
426- def _cleanup_done (cleaner : " asyncio.Task" ):
424+ def _cleanup_done (cleaner : asyncio .Task ):
427425 self .receive_cleaners .remove (cleaner )
428426 self .receive_clean_locks .release (channel_key )
429427
@@ -497,8 +495,8 @@ async def group_add(self, group: str, channel: str):
497495 Adds the channel name to a group.
498496 """
499497 # Check the inputs
500- assert self .require_valid_group_name (group ), True
501- assert self .require_valid_channel_name (channel ), True
498+ assert self .valid_group_name (group ), True
499+ assert self .valid_channel_name (channel ), True
502500 # Get a connection to the right shard
503501 group_key = self ._group_key (group )
504502 connection = self .connection (self .consistent_hash (group ))
@@ -513,8 +511,8 @@ async def group_discard(self, group: str, channel: str):
513511 Removes the channel from the named group if it is in the group;
514512 does nothing otherwise (does not error)
515513 """
516- assert self .require_valid_group_name (group ), "Group name not valid"
517- assert self .require_valid_channel_name (channel ), "Channel name not valid"
514+ assert self .valid_group_name (group ), "Group name not valid"
515+ assert self .valid_channel_name (channel ), "Channel name not valid"
518516 key = self ._group_key (group )
519517 connection = self .connection (self .consistent_hash (group ))
520518 await connection .zrem (key , channel )
@@ -523,7 +521,7 @@ async def group_send(self, group: str, message):
523521 """
524522 Sends a message to the entire group.
525523 """
526- assert self .require_valid_group_name (group ), "Group name not valid"
524+ assert self .valid_group_name (group ), "Group name not valid"
527525 # Retrieve list of all channel names
528526 key = self ._group_key (group )
529527 connection = self .connection (self .consistent_hash (group ))
@@ -671,7 +669,7 @@ def deserialize(self, message: bytes):
671669
672670 ### Internal functions ###
673671
674- def consistent_hash (self , value : " Union[str, Buffer]" ) -> int :
672+ def consistent_hash (self , value : typing . Union [str , " Buffer" ] ) -> int :
675673 return _consistent_hash (value , self .ring_size )
676674
677675 def __str__ (self ):
0 commit comments