99
1010from redis import asyncio as aioredis
1111
12- from channels .exceptions import ChannelFull
13- from channels .layers import BaseChannelLayer
12+ from channels .exceptions import ChannelFull # type: ignore[import-untyped]
13+ from channels .layers import BaseChannelLayer # type: ignore[import-untyped]
1414
1515from .serializers import registry
1616from .utils import (
@@ -37,13 +37,11 @@ class ChannelLock:
3737 to mitigate multi-event loop problems.
3838 """
3939
40- def __init__ (self ):
41- self .locks : collections .defaultdict [str , asyncio .Lock ] = (
42- collections .defaultdict (asyncio .Lock )
43- )
44- self .wait_counts : collections .defaultdict [str , int ] = collections .defaultdict (
45- int
40+ def __init__ (self ) -> None :
41+ self .locks : typing .DefaultDict [str , asyncio .Lock ] = collections .defaultdict (
42+ asyncio .Lock
4643 )
44+ self .wait_counts : typing .DefaultDict [str , int ] = collections .defaultdict (int )
4745
4846 async def acquire (self , channel : str ) -> bool :
4947 """
@@ -58,7 +56,7 @@ def locked(self, channel: str) -> bool:
5856 """
5957 return self .locks [channel ].locked ()
6058
61- def release (self , channel : str ):
59+ def release (self , channel : str ) -> None :
6260 """
6361 Release the lock for the given channel.
6462 """
@@ -70,7 +68,7 @@ def release(self, channel: str):
7068
7169
7270class BoundedQueue (asyncio .Queue ):
73- def put_nowait (self , item ) :
71+ def put_nowait (self , item : typing . Any ) -> None :
7472 if self .full ():
7573 # see: https://github.com/django/channels_redis/issues/212
7674 # if we actually get into this code block, it likely means that
@@ -83,7 +81,7 @@ def put_nowait(self, item):
8381
8482
8583class RedisLoopLayer :
86- def __init__ (self , channel_layer : "RedisChannelLayer" ):
84+ def __init__ (self , channel_layer : "RedisChannelLayer" ) -> None :
8785 self ._lock = asyncio .Lock ()
8886 self .channel_layer = channel_layer
8987 self ._connections : typing .Dict [int , "Redis" ] = {}
@@ -95,7 +93,7 @@ def get_connection(self, index: int) -> "Redis":
9593
9694 return self ._connections [index ]
9795
98- async def flush (self ):
96+ async def flush (self ) -> None :
9997 async with self ._lock :
10098 for index in list (self ._connections ):
10199 connection = self ._connections .pop (index )
@@ -116,15 +114,15 @@ class RedisChannelLayer(BaseChannelLayer):
116114 def __init__ (
117115 self ,
118116 hosts = None ,
119- prefix = "asgi" ,
117+ prefix : str = "asgi" ,
120118 expiry = 60 ,
121- group_expiry = 86400 ,
119+ group_expiry : int = 86400 ,
122120 capacity = 100 ,
123121 channel_capacity = None ,
124122 symmetric_encryption_keys = None ,
125123 random_prefix_length = 12 ,
126124 serializer_format = "msgpack" ,
127- ):
125+ ) -> None :
128126 # Store basic information
129127 self .expiry = expiry
130128 self .group_expiry = group_expiry
@@ -157,11 +155,11 @@ def __init__(
157155 # Event loop they are trying to receive on
158156 self .receive_event_loop : typing .Optional [asyncio .AbstractEventLoop ] = None
159157 # Buffered messages by process-local channel name
160- self .receive_buffer : collections . defaultdict [str , BoundedQueue ] = (
158+ self .receive_buffer : typing . DefaultDict [str , BoundedQueue ] = (
161159 collections .defaultdict (functools .partial (BoundedQueue , self .capacity ))
162160 )
163161 # Detached channel cleanup tasks
164- self .receive_cleaners : typing .List [asyncio .Task ] = []
162+ self .receive_cleaners : typing .List [" asyncio.Task[typing.Any]" ] = []
165163 # Per-channel cleanup locks to prevent a receive starting and moving
166164 # a message back into the main queue before its cleanup has completed
167165 self .receive_clean_locks = ChannelLock ()
@@ -173,7 +171,7 @@ def create_pool(self, index: int) -> "ConnectionPool":
173171
174172 extensions = ["groups" , "flush" ]
175173
176- async def send (self , channel : str , message ) :
174+ async def send (self , channel : str , message : typing . Any ) -> None :
177175 """
178176 Send a message onto a (general or specific) channel.
179177 """
@@ -221,7 +219,7 @@ def _backup_channel_name(self, channel: str) -> str:
221219
222220 async def _brpop_with_clean (
223221 self , index : int , channel : str , timeout : typing .Union [int , float , bytes , str ]
224- ):
222+ ) -> typing . Any :
225223 """
226224 Perform a Redis BRPOP and manage the backup processing queue.
227225 In case of cancellation, make sure the message is not lost.
@@ -240,7 +238,7 @@ async def _brpop_with_clean(
240238 connection = self .connection (index )
241239 # Cancellation here doesn't matter, we're not doing anything destructive
242240 # and the script executes atomically...
243- await connection .eval (cleanup_script , 0 , channel , backup_queue )
241+ await connection .eval (cleanup_script , 0 , channel , backup_queue ) # type: ignore[misc]
244242 # ...and it doesn't matter here either, the message will be safe in the backup.
245243 result = await connection .bzpopmin (channel , timeout = timeout )
246244
@@ -252,15 +250,15 @@ async def _brpop_with_clean(
252250
253251 return member
254252
255- async def _clean_receive_backup (self , index : int , channel : str ):
253+ async def _clean_receive_backup (self , index : int , channel : str ) -> None :
256254 """
257255 Pop the oldest message off the channel backup queue.
258256 The result isn't interesting as it was already processed.
259257 """
260258 connection = self .connection (index )
261259 await connection .zpopmin (self ._backup_channel_name (channel ))
262260
263- async def receive (self , channel : str ):
261+ async def receive (self , channel : str ) -> typing . Any :
264262 """
265263 Receive the first message that arrives on the channel.
266264 If more than one coroutine waits on the same channel, the first waiter
@@ -292,11 +290,11 @@ async def receive(self, channel: str):
292290 # Wait for our message to appear
293291 message = None
294292 while self .receive_buffer [channel ].empty ():
295- tasks = [
296- self .receive_lock .acquire (),
293+ _tasks = [
294+ self .receive_lock .acquire (), # type: ignore[union-attr]
297295 self .receive_buffer [channel ].get (),
298296 ]
299- tasks = [asyncio .ensure_future (task ) for task in tasks ]
297+ tasks = [asyncio .ensure_future (task ) for task in _tasks ]
300298 try :
301299 done , pending = await asyncio .wait (
302300 tasks , return_when = asyncio .FIRST_COMPLETED
@@ -312,7 +310,7 @@ async def receive(self, channel: str):
312310 if not task .cancel ():
313311 assert task .done ()
314312 if task .result () is True :
315- self .receive_lock .release ()
313+ self .receive_lock .release () # type: ignore[union-attr]
316314
317315 raise
318316
@@ -335,7 +333,7 @@ async def receive(self, channel: str):
335333 if message or exception :
336334 if token :
337335 # We will not be receving as we already have the message.
338- self .receive_lock .release ()
336+ self .receive_lock .release () # type: ignore[union-attr]
339337
340338 if exception :
341339 raise exception
@@ -362,7 +360,7 @@ async def receive(self, channel: str):
362360 del self .receive_buffer [channel ]
363361 raise
364362 finally :
365- self .receive_lock .release ()
363+ self .receive_lock .release () # type: ignore[union-attr]
366364
367365 # We know there's a message available, because there
368366 # couldn't have been any interruption between empty() and here
@@ -377,14 +375,16 @@ async def receive(self, channel: str):
377375 self .receive_count -= 1
378376 # If we were the last out, drop the receive lock
379377 if self .receive_count == 0 :
380- assert not self .receive_lock .locked ()
378+ assert not self .receive_lock .locked () # type: ignore[union-attr]
381379 self .receive_lock = None
382380 self .receive_event_loop = None
383381 else :
384382 # Do a plain direct receive
385383 return (await self .receive_single (channel ))[1 ]
386384
387- async def receive_single (self , channel : str ) -> typing .Tuple :
385+ async def receive_single (
386+ self , channel : str
387+ ) -> typing .Tuple [typing .Any , typing .Any ]:
388388 """
389389 Receives a single message off of the channel and returns it.
390390 """
@@ -420,7 +420,7 @@ async def receive_single(self, channel: str) -> typing.Tuple:
420420 )
421421 self .receive_cleaners .append (cleaner )
422422
423- def _cleanup_done (cleaner : asyncio .Task ) :
423+ def _cleanup_done (cleaner : " asyncio.Task" ) -> None :
424424 self .receive_cleaners .remove (cleaner )
425425 self .receive_clean_locks .release (channel_key )
426426
@@ -448,7 +448,7 @@ async def new_channel(self, prefix: str = "specific") -> str:
448448
449449 ### Flush extension ###
450450
451- async def flush (self ):
451+ async def flush (self ) -> None :
452452 """
453453 Deletes all messages and groups on all shards.
454454 """
@@ -466,11 +466,11 @@ async def flush(self):
466466 # Go through each connection and remove all with prefix
467467 for i in range (self .ring_size ):
468468 connection = self .connection (i )
469- await connection .eval (delete_prefix , 0 , self .prefix + "*" )
469+ await connection .eval (delete_prefix , 0 , self .prefix + "*" ) # type: ignore[union-attr,misc]
470470 # Now clear the pools as well
471471 await self .close_pools ()
472472
473- async def close_pools (self ):
473+ async def close_pools (self ) -> None :
474474 """
475475 Close all connections in the event loop pools.
476476 """
@@ -480,7 +480,7 @@ async def close_pools(self):
480480 for layer in self ._layers .values ():
481481 await layer .flush ()
482482
483- async def wait_received (self ):
483+ async def wait_received (self ) -> None :
484484 """
485485 Wait for all channel cleanup functions to finish.
486486 """
@@ -489,13 +489,13 @@ async def wait_received(self):
489489
490490 ### Groups extension ###
491491
492- async def group_add (self , group : str , channel : str ):
492+ async def group_add (self , group : str , channel : str ) -> None :
493493 """
494494 Adds the channel name to a group.
495495 """
496496 # Check the inputs
497- assert self .valid_group_name (group ), True
498- assert self .valid_channel_name (channel ), True
497+ assert self .valid_group_name (group ), "Group name not valid"
498+ assert self .valid_channel_name (channel ), "Channel name not valid"
499499 # Get a connection to the right shard
500500 group_key = self ._group_key (group )
501501 connection = self .connection (self .consistent_hash (group ))
@@ -505,7 +505,7 @@ async def group_add(self, group: str, channel: str):
505505 # it at this point is guaranteed to expire before that
506506 await connection .expire (group_key , self .group_expiry )
507507
508- async def group_discard (self , group : str , channel : str ):
508+ async def group_discard (self , group : str , channel : str ) -> None :
509509 """
510510 Removes the channel from the named group if it is in the group;
511511 does nothing otherwise (does not error)
@@ -516,7 +516,7 @@ async def group_discard(self, group: str, channel: str):
516516 connection = self .connection (self .consistent_hash (group ))
517517 await connection .zrem (key , channel )
518518
519- async def group_send (self , group : str , message ) :
519+ async def group_send (self , group : str , message : typing . Any ) -> None :
520520 """
521521 Sends a message to the entire group.
522522 """
@@ -540,9 +540,9 @@ async def group_send(self, group: str, message):
540540 for connection_index , channel_redis_keys in connection_to_channel_keys .items ():
541541 # Discard old messages based on expiry
542542 pipe = connection .pipeline ()
543- for key in channel_redis_keys :
543+ for _key in channel_redis_keys :
544544 pipe .zremrangebyscore (
545- key , min = 0 , max = int (time .time ()) - int (self .expiry )
545+ _key , min = 0 , max = int (time .time ()) - int (self .expiry )
546546 )
547547 await pipe .execute ()
548548
@@ -582,10 +582,10 @@ async def group_send(self, group: str, message):
582582
583583 # channel_keys does not contain a single redis key more than once
584584 connection = self .connection (connection_index )
585- channels_over_capacity = await connection .eval (
585+ channels_over_capacity = await connection .eval ( # type: ignore[misc]
586586 group_send_lua , len (channel_redis_keys ), * channel_redis_keys , * args
587587 )
588- _channels_over_capacity = - 1
588+ _channels_over_capacity = - 1.0
589589 try :
590590 _channels_over_capacity = float (channels_over_capacity )
591591 except Exception :
@@ -598,7 +598,13 @@ async def group_send(self, group: str, message):
598598 group ,
599599 )
600600
601- def _map_channel_keys_to_connection (self , channel_names , message ):
601+ def _map_channel_keys_to_connection (
602+ self , channel_names : typing .Iterable [str ], message : typing .Any
603+ ) -> typing .Tuple [
604+ typing .Dict [int , typing .List [str ]],
605+ typing .Dict [str , typing .Any ],
606+ typing .Dict [str , int ],
607+ ]:
602608 """
603609 For a list of channel names, GET
604610
@@ -611,19 +617,21 @@ def _map_channel_keys_to_connection(self, channel_names, message):
611617 """
612618
613619 # Connection dict keyed by index to list of redis keys mapped on that index
614- connection_to_channel_keys = collections .defaultdict (list )
620+ connection_to_channel_keys : typing .Dict [int , typing .List [str ]] = (
621+ collections .defaultdict (list )
622+ )
615623 # Message dict maps redis key to the message that needs to be send on that key
616- channel_key_to_message = dict ()
624+ channel_key_to_message : typing . Dict [ str , typing . Any ] = dict ()
617625 # Channel key mapped to its capacity
618- channel_key_to_capacity = dict ()
626+ channel_key_to_capacity : typing . Dict [ str , int ] = dict ()
619627
620628 # For each channel
621629 for channel in channel_names :
622630 channel_non_local_name = channel
623631 if "!" in channel :
624632 channel_non_local_name = self .non_local_name (channel )
625633 # Get its redis key
626- channel_key = self .prefix + channel_non_local_name
634+ channel_key : str = self .prefix + channel_non_local_name
627635 # Have we come across the same redis key?
628636 if channel_key not in channel_key_to_message :
629637 # If not, fill the corresponding dicts
@@ -654,13 +662,15 @@ def _group_key(self, group: str) -> bytes:
654662 """
655663 return f"{ self .prefix } :group:{ group } " .encode ("utf8" )
656664
657- def serialize (self , message ) -> bytes :
665+ ### Serialization ###
666+
667+ def serialize (self , message : typing .Any ) -> bytes :
658668 """
659669 Serializes message to a byte string.
660670 """
661671 return self ._serializer .serialize (message )
662672
663- def deserialize (self , message : bytes ):
673+ def deserialize (self , message : bytes ) -> typing . Any :
664674 """
665675 Deserializes from a byte string.
666676 """
@@ -671,7 +681,7 @@ def deserialize(self, message: bytes):
671681 def consistent_hash (self , value : typing .Union [str , "Buffer" ]) -> int :
672682 return _consistent_hash (value , self .ring_size )
673683
674- def __str__ (self ):
684+ def __str__ (self ) -> str :
675685 return f"{ self .__class__ .__name__ } (hosts={ self .hosts } )"
676686
677687 ### Connection handling ###
0 commit comments