Skip to content

Commit 784fa88

Browse files
authored
Add support for jitter so that we can ensure evenly distributed tasks don't cause all workers to restart at same time (#570)
1 parent eaed387 commit 784fa88

7 files changed

Lines changed: 84 additions & 2 deletions

File tree

docs/guide/cli.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ The number of signals before a hard kill can be configured with the `--hardkill-
138138
* `--log-level` is used to set a log level (default `INFO`).
139139
* `--log-format` is used to set a log format (default `%(asctime)s][%(name)s][%(levelname)-7s][%(processName)s] %(message)s`).
140140
* `--max-async-tasks` - maximum number of simultaneously running async tasks.
141+
* `--max-async-tasks-jitter` – Randomly varies the max async task limit between --max-async-tasks and a jittered value, helping prevent simultaneous worker restarts.
141142
* `--max-prefetch` - number of tasks to be prefetched before execution. (Useful for systems with high message rates, but brokers should support acknowledgements).
142143
* `--max-threadpool-threads` - number of threads for sync function execution.
143144
* `--no-propagate-errors` - if this parameter is enabled, exceptions won't be thrown in generator dependencies.

taskiq/api/receiver.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ async def run_receiver_task(
1515
sync_workers: int | None = None,
1616
validate_params: bool = True,
1717
max_async_tasks: int = 100,
18+
max_async_tasks_jitter: int = 0,
1819
max_prefetch: int = 0,
1920
propagate_exceptions: bool = True,
2021
run_startup: bool = False,
@@ -43,6 +44,7 @@ async def run_receiver_task(
4344
or processes in processpool that runs sync tasks.
4445
:param validate_params: whether to validate params or not.
4546
:param max_async_tasks: maximum number of simultaneous async tasks.
47+
:param max_async_tasks_jitter: random jitter to add to max_async_tasks.
4648
:param max_prefetch: maximum number of tasks to prefetch.
4749
:param propagate_exceptions: whether to propagate exceptions in generators or not.
4850
:param run_startup: whether to run startup function or not.
@@ -79,6 +81,7 @@ def on_exit(_: Receiver) -> None:
7981
run_startup=run_startup,
8082
validate_params=validate_params,
8183
max_async_tasks=max_async_tasks,
84+
max_async_tasks_jitter=max_async_tasks_jitter,
8285
max_prefetch=max_prefetch,
8386
propagate_exceptions=propagate_exceptions,
8487
on_exit=on_exit,

taskiq/brokers/inmemory_broker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def __init__(
127127
max_stored_results: int = 100,
128128
cast_types: bool = True,
129129
max_async_tasks: int = 30,
130+
max_async_tasks_jitter: int = 0,
130131
propagate_exceptions: bool = True,
131132
await_inplace: bool = False,
132133
) -> None:
@@ -140,6 +141,7 @@ def __init__(
140141
executor=self.executor,
141142
validate_params=cast_types,
142143
max_async_tasks=max_async_tasks,
144+
max_async_tasks_jitter=max_async_tasks_jitter,
143145
propagate_exceptions=propagate_exceptions,
144146
)
145147
self.await_inplace = await_inplace

taskiq/cli/worker/args.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class WorkerArgs:
4444
reload_dirs: list[str] = field(default_factory=list)
4545
no_gitignore: bool = False
4646
max_async_tasks: int = 100
47+
max_async_tasks_jitter: int = 0
4748
receiver: str = "taskiq.receiver:Receiver"
4849
receiver_arg: list[tuple[str, str]] = field(default_factory=list)
4950
max_prefetch: int = 0
@@ -210,6 +211,14 @@ def from_cli(
210211
default=100,
211212
help="Maximum simultaneous async tasks per worker process. ",
212213
)
214+
parser.add_argument(
215+
"--max-async-tasks-jitter",
216+
type=int,
217+
dest="max_async_tasks_jitter",
218+
default=0,
219+
help="Add random jitter (0 to this value) to max-async-tasks to prevent "
220+
"all workers from closing at the same time. ",
221+
)
213222
parser.add_argument(
214223
"--max-prefetch",
215224
type=int,

taskiq/cli/worker/run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
165165
executor=pool,
166166
validate_params=not args.no_parse,
167167
max_async_tasks=args.max_async_tasks,
168+
max_async_tasks_jitter=args.max_async_tasks_jitter,
168169
max_prefetch=args.max_prefetch,
169170
propagate_exceptions=not args.no_propagate_errors,
170171
ack_type=args.ack_type,

taskiq/receiver/receiver.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import contextvars
33
import functools
44
import inspect
5+
import random
56
import sys
67
from collections.abc import Callable
78
from concurrent.futures import Executor, ProcessPoolExecutor
@@ -55,6 +56,7 @@ def __init__(
5556
executor: Executor | None = None,
5657
validate_params: bool = True,
5758
max_async_tasks: "int | None" = None,
59+
max_async_tasks_jitter: int = 0,
5860
max_prefetch: int = 0,
5961
propagate_exceptions: bool = True,
6062
run_startup: bool = True,
@@ -80,7 +82,15 @@ def __init__(
8082
self._prepare_task(task.task_name, task.original_func)
8183
self.sem: asyncio.Semaphore | None = None
8284
if max_async_tasks is not None and max_async_tasks > 0:
83-
self.sem = asyncio.Semaphore(max_async_tasks)
85+
# Apply jitter to prevent all workers from hitting the limit simultaneously
86+
actual_limit = max_async_tasks
87+
if max_async_tasks_jitter > 0:
88+
# Using standard random for load distribution, not cryptography
89+
actual_limit = max_async_tasks + random.randint( # noqa: S311
90+
0,
91+
max_async_tasks_jitter,
92+
)
93+
self.sem = asyncio.Semaphore(actual_limit)
8494
else:
8595
logger.warning(
8696
"Setting unlimited number of async tasks "

tests/receiver/test_receiver.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import contextvars
33
import random
44
import time
5+
import unittest.mock
56
from collections.abc import Generator
67
from concurrent.futures import ThreadPoolExecutor
78
from functools import wraps
@@ -24,13 +25,15 @@ def get_receiver(
2425
broker: AsyncBroker | None = None,
2526
no_parse: bool = False,
2627
max_async_tasks: int | None = None,
28+
max_async_tasks_jitter: int = 0,
2729
) -> Receiver:
2830
"""
2931
Returns receiver with custom broker and args.
3032
3133
:param broker: broker, defaults to None
3234
:param no_parse: parameter to taskiq_args, defaults to False
33-
:param cli_args: Taskiq worker CLI arguments.
35+
:param max_async_tasks: maximum number of simultaneous async tasks.
36+
:param max_async_tasks_jitter: random jitter to add to max_async_tasks.
3437
:return: new receiver.
3538
"""
3639
if broker is None:
@@ -40,6 +43,7 @@ def get_receiver(
4043
executor=ThreadPoolExecutor(max_workers=10),
4144
validate_params=not no_parse,
4245
max_async_tasks=max_async_tasks,
46+
max_async_tasks_jitter=max_async_tasks_jitter,
4347
)
4448

4549

@@ -544,3 +548,55 @@ async def task_no_result() -> str:
544548
assert resp.return_value == "some value"
545549
assert not broker._running_tasks
546550
assert wrapper_call is True
551+
552+
553+
async def test_jitter_applied_to_semaphore() -> None:
554+
"""Test that jitter is correctly applied to max_async_tasks semaphore."""
555+
max_async_tasks = 100
556+
max_async_tasks_jitter = 10
557+
558+
# Test with jitter value of 0 (minimum)
559+
with unittest.mock.patch("random.randint", return_value=0):
560+
receiver = get_receiver(
561+
max_async_tasks=max_async_tasks,
562+
max_async_tasks_jitter=max_async_tasks_jitter,
563+
)
564+
assert receiver.sem is not None
565+
assert receiver.sem._value == max_async_tasks
566+
567+
# Test with jitter value of 5 (middle)
568+
with unittest.mock.patch("random.randint", return_value=5):
569+
receiver = get_receiver(
570+
max_async_tasks=max_async_tasks,
571+
max_async_tasks_jitter=max_async_tasks_jitter,
572+
)
573+
assert receiver.sem is not None
574+
assert receiver.sem._value == max_async_tasks + 5
575+
576+
# Test with jitter value of 10 (maximum)
577+
with unittest.mock.patch("random.randint", return_value=10):
578+
receiver = get_receiver(
579+
max_async_tasks=max_async_tasks,
580+
max_async_tasks_jitter=max_async_tasks_jitter,
581+
)
582+
assert receiver.sem is not None
583+
assert receiver.sem._value == max_async_tasks + 10
584+
585+
586+
async def test_jitter_zero_no_randomization() -> None:
587+
"""Test with zero jitter, semaphore value matches max_async_tasks."""
588+
max_async_tasks = 50
589+
590+
receiver = get_receiver(
591+
max_async_tasks=max_async_tasks,
592+
max_async_tasks_jitter=0,
593+
)
594+
595+
assert receiver.sem is not None
596+
assert receiver.sem._value == max_async_tasks
597+
598+
599+
async def test_no_semaphore_without_max_async_tasks() -> None:
600+
"""Test that semaphore is None when max_async_tasks is not set."""
601+
receiver = get_receiver(max_async_tasks=None)
602+
assert receiver.sem is None

0 commit comments

Comments
 (0)