Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES/11969.doc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Added rate-limiting client middleware example (``examples/rate_limit_middleware.py``)
demonstrating token-bucket rate limiting with per-domain support and ``Retry-After``
header handling -- by :user:`rodrigobnogueira`.
168 changes: 168 additions & 0 deletions examples/rate_limit_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
#!/usr/bin/env python3
"""
Client-side rate-limiting middleware example for aiohttp.

Demonstrates how to throttle outgoing requests using a token-bucket
algorithm. This is *not* server-side rate limiting — it limits how
fast the client sends requests so it does not overwhelm upstream
servers or exceed API quotas.

Features:
- Configurable rate and burst size
- Optional per-domain buckets
- Automatic ``Retry-After`` header handling
"""

import asyncio
import logging
import time
from collections import defaultdict, deque
from http import HTTPStatus

from aiohttp import ClientHandlerType, ClientRequest, ClientResponse, ClientSession, web

# Configure the root logger as soon as this module is imported so that the
# middleware's INFO-level messages (e.g. honoured Retry-After delays) are
# visible when the example is run as a script.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the middleware below.
_LOGGER = logging.getLogger(__name__)


class TokenBucket:
    """FIFO token bucket built on a queue of ``asyncio.Event`` objects.

    Each caller appends its own event to a FIFO queue and waits on it.
    A single ``_schedule`` task services the queue front-to-back,
    sleeping until each slot's send time arrives and then setting the
    event at the head of the queue, so callers are released in the
    order in which they arrived.

    A caller that is cancelled while waiting is removed from the queue,
    so the slot it was waiting for goes to the next waiter instead of
    being spent on nobody.

    Args:
        rate: Sustained number of acquisitions per second; must be > 0.
        burst: Number of acquisitions that may complete back-to-back
            after the bucket has been idle; must be >= 0.

    Raises:
        ValueError: If *rate* is not positive or *burst* is negative.
    """

    def __init__(self, rate: float, burst: int) -> None:
        # ``not rate > 0`` (rather than ``rate <= 0``) also rejects NaN.
        if not rate > 0:
            raise ValueError(f"rate must be positive, got {rate!r}")
        if burst < 0:
            raise ValueError(f"burst must be non-negative, got {burst!r}")
        self._interval = 1.0 / rate
        self._burst = burst
        # Start *burst* intervals in the past so the first
        # ``burst`` acquires are instant.
        self._next_send = time.monotonic() - burst * self._interval
        self._waiters: deque[asyncio.Event] = deque()
        self._scheduling: bool = False
        # Strong reference to the running scheduler task; the event loop
        # itself only keeps weak references to tasks.
        self._scheduler_task: asyncio.Task[None] | None = None

    async def acquire(self) -> None:
        """Reserve the next send slot and wait until it arrives.

        Raises:
            asyncio.CancelledError: If the calling task is cancelled while
                waiting. The caller's queue entry is withdrawn first.
        """
        event = asyncio.Event()
        self._waiters.append(event)
        self._ensure_scheduling()
        try:
            await event.wait()
        except asyncio.CancelledError:
            # Leave the line so the scheduler does not burn a slot on a
            # caller that is no longer there.
            try:
                self._waiters.remove(event)
            except ValueError:
                # Already dequeued: the slot was granted just before the
                # cancellation landed, nothing left to withdraw.
                pass
            raise

    def _ensure_scheduling(self) -> None:
        """Start the scheduler task if it is not already running."""
        if not self._scheduling:
            self._scheduling = True
            self._scheduler_task = asyncio.create_task(self._schedule())

    async def _schedule(self) -> None:
        """Service waiters in FIFO order, one slot at a time."""
        try:
            while self._waiters:
                now = time.monotonic()
                # Cap drift so idle periods never accumulate
                # more than *burst* free slots.
                self._next_send = max(
                    self._next_send, now - self._burst * self._interval
                )
                self._next_send += self._interval
                delay = self._next_send - now
                if delay > 0:
                    await asyncio.sleep(delay)
                if self._waiters:
                    self._waiters.popleft().set()
                else:
                    # Every waiter was cancelled while we slept: hand the
                    # unused slot back so the next caller can take it.
                    self._next_send -= self._interval
        finally:
            # Always clear the flag, even on an unexpected exit, so a later
            # ``acquire`` can start a fresh scheduler instead of hanging.
            self._scheduling = False
            self._scheduler_task = None


class RateLimitMiddleware:
    """Client middleware that throttles outgoing requests with token buckets.

    Every request first acquires a slot from a ``TokenBucket`` (one shared
    bucket, or one bucket per host when *per_domain* is true) and is only
    then passed on to the next handler.

    When *respect_retry_after* is true and the response is ``429 Too Many
    Requests`` with a numeric ``Retry-After`` header, the middleware sleeps
    for that many seconds before handing the response back to the caller.
    It does not retry the request and does not pause other requests.

    Args:
        rate: Sustained requests per second allowed by each bucket.
        burst: Requests each bucket lets through back-to-back when idle.
        per_domain: Use a separate bucket per request host instead of a
            single global bucket.
        respect_retry_after: Sleep according to ``Retry-After`` on 429
            responses.
    """

    rate: float
    burst: int
    per_domain: bool
    respect_retry_after: bool

    def __init__(
        self,
        rate: float = 10.0,
        burst: int = 10,
        per_domain: bool = False,
        respect_retry_after: bool = True,
    ) -> None:
        self.rate = rate
        self.burst = burst
        self.per_domain = per_domain
        self.respect_retry_after = respect_retry_after
        self._global_bucket = TokenBucket(rate, burst)
        # Buckets are created lazily the first time a host is seen.
        self._domain_buckets: dict[str, TokenBucket] = defaultdict(
            lambda: TokenBucket(rate, burst)
        )

    def _get_bucket(self, request: ClientRequest) -> TokenBucket:
        """Return the bucket that governs *request*."""
        if self.per_domain:
            # Requests whose URL has no host all share the "unknown" bucket.
            domain = request.url.host or "unknown"
            return self._domain_buckets[domain]
        return self._global_bucket

    async def _handle_retry_after(self, response: ClientResponse) -> None:
        """Sleep for the server-requested delay on a 429 response.

        Only a finite, non-negative number of seconds is honoured. Values
        that are not numbers (for example an HTTP-date), or that are
        negative, infinite or NaN, are logged at debug level and ignored so
        a bad header can never stall the client indefinitely.
        """
        if response.status != HTTPStatus.TOO_MANY_REQUESTS:
            return
        retry_after = response.headers.get("Retry-After")
        if not retry_after:
            return
        try:
            wait_seconds = float(retry_after)
        except ValueError:
            _LOGGER.debug(
                "Retry-After is not a number (likely HTTP-date): %s", retry_after
            )
            return
        # ``float()`` also accepts "inf", "nan" and negative numbers. NaN
        # fails every comparison, so this single range check rejects all
        # three.
        if not (0 <= wait_seconds < float("inf")):
            _LOGGER.debug("Ignoring out-of-range Retry-After value: %s", retry_after)
            return
        _LOGGER.info("Server requested Retry-After: %ss", wait_seconds)
        await asyncio.sleep(wait_seconds)

    async def __call__(
        self,
        request: ClientRequest,
        handler: ClientHandlerType,
    ) -> ClientResponse:
        """Execute request with rate limiting."""
        bucket = self._get_bucket(request)
        await bucket.acquire()

        response = await handler(request)

        if self.respect_retry_after:
            await self._handle_retry_after(response)

        return response


# ------------------------------------------------------------------
# Self-contained demo: starts its own local server, so no external services are needed
async def _demo_handler(_request: web.Request) -> web.Response:
    """Answer every demo request with a plain-text ``OK`` body."""
    reply = web.Response(text="OK")
    return reply


async def main() -> None:
    """Run the demo against a throwaway local server.

    Starts an aiohttp application on an OS-assigned port of 127.0.0.1,
    sends five sequential ``GET /get`` requests through a
    ``RateLimitMiddleware`` configured with ``rate=5.0`` and ``burst=2``,
    and prints the status and elapsed time of each. The runner is always
    cleaned up, including when the site fails to start.
    """
    app = web.Application()
    _ = app.router.add_get("/get", _demo_handler)
    runner = web.AppRunner(app)
    await runner.setup()

    # Everything after ``setup()`` lives inside the ``try`` so the runner is
    # released even if binding the listening socket fails.
    try:
        # Port 0 asks the OS for any free port; the real one is read back below.
        site = web.TCPSite(runner, "127.0.0.1", 0)
        await site.start()

        port: int = site.port
        rate_limit = RateLimitMiddleware(rate=5.0, burst=2)
        start = time.monotonic()

        async with ClientSession(
            base_url=f"http://127.0.0.1:{port}",
            middlewares=(rate_limit,),
        ) as session:
            for i in range(5):
                async with session.get("/get") as resp:
                    elapsed = time.monotonic() - start
                    print(f"Request {i + 1}: {resp.status} at t={elapsed:.2f}s")
    finally:
        await runner.cleanup()


# Run the demo only when this file is executed as a script, not when it is
# imported (for example by the tests).
if __name__ == "__main__":
    asyncio.run(main())
2 changes: 2 additions & 0 deletions examples/tests/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
pythonpath = ..
137 changes: 137 additions & 0 deletions examples/tests/test_rate_limit_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""
Tests for the rate_limit_middleware.py example.

Run with:
pytest examples/tests/test_rate_limit_middleware.py -v
"""

import asyncio
import time

import pytest
from rate_limit_middleware import RateLimitMiddleware, TokenBucket

from aiohttp import web
from aiohttp.pytest_plugin import AiohttpClient


async def _ok_handler(request: web.Request) -> web.Response:
    """Reply to any request with a plain-text ``OK`` body."""
    ok_response = web.Response(text="OK")
    return ok_response


def _make_app() -> web.Application:
    """Build a minimal application that serves ``GET /api``."""
    application = web.Application()
    application.router.add_get("/api", _ok_handler)
    return application


@pytest.mark.asyncio
async def test_token_bucket_allows_burst() -> None:
    """The first ``burst`` acquisitions must complete without throttling."""
    bucket = TokenBucket(rate=10.0, burst=3)
    began = time.monotonic()
    remaining = 3
    while remaining:
        await bucket.acquire()
        remaining -= 1
    took = time.monotonic() - began
    # None of the three acquisitions should have slept.
    assert took < 0.05


@pytest.mark.asyncio
async def test_token_bucket_refills_after_idle() -> None:
    """A drained bucket regains its slot once it has been left idle."""
    bucket = TokenBucket(rate=100.0, burst=1)

    # Use up the single burst slot, then stay idle for several intervals.
    await bucket.acquire()
    await asyncio.sleep(0.05)

    t0 = time.monotonic()
    await bucket.acquire()
    waited = time.monotonic() - t0

    # The idle period refilled the slot, so there is nothing to wait for.
    assert waited < 0.02


@pytest.mark.asyncio
async def test_token_bucket_fifo_ordering() -> None:
    """Callers waiting concurrently are released in arrival order."""
    bucket = TokenBucket(rate=100.0, burst=1)
    released: list[int] = []

    async def numbered_acquire(n: int) -> None:
        await bucket.acquire()
        released.append(n)

    # Unpacking creates every task, in order, before ``gather`` is called.
    await asyncio.gather(
        *(asyncio.create_task(numbered_acquire(idx)) for idx in range(3))
    )
    assert released == [0, 1, 2]


@pytest.mark.asyncio
async def test_rate_limit_middleware_throttles(
    aiohttp_client: AiohttpClient,
) -> None:
    """Requests beyond the burst allowance are slowed by the global bucket."""
    limiter = RateLimitMiddleware(rate=50.0, burst=2)
    client = await aiohttp_client(_make_app(), middlewares=(limiter,))

    began = time.monotonic()
    sent = 0
    while sent < 4:
        resp = await client.get("/api")
        assert resp.status == 200
        sent += 1
    took = time.monotonic() - began

    # Two requests ride the burst; the other two are spaced at 50/s,
    # i.e. roughly 0.04s of waiting in total. The 0.5s ceiling catches
    # hangs or accidental double-sleeps while leaving slack for slow CI.
    assert 0.02 <= took < 0.5


@pytest.mark.asyncio
async def test_rate_limit_middleware_per_domain(
    aiohttp_client: AiohttpClient,
) -> None:
    """In per-domain mode, two requests to one host share that host's bucket."""
    limiter = RateLimitMiddleware(rate=100.0, burst=1, per_domain=True)
    client = await aiohttp_client(_make_app(), middlewares=(limiter,))

    began = time.monotonic()
    # Both requests target the same host, so the second one has to wait
    # for the bucket to hand out another slot.
    first = await client.get("/api")
    second = await client.get("/api")
    took = time.monotonic() - began

    assert first.status == 200
    assert second.status == 200
    # The ceiling flags unexpected stalls without being flaky on CI.
    assert 0.005 <= took < 0.5


@pytest.mark.asyncio
async def test_rate_limit_middleware_respects_retry_after(
    aiohttp_client: AiohttpClient,
) -> None:
    """A 429 carrying ``Retry-After`` delays the response by that long."""
    hits = 0

    async def rate_limited_handler(request: web.Request) -> web.Response:
        nonlocal hits
        hits += 1
        if hits > 1:
            return web.Response(text="OK")
        # Only the very first call is rejected.
        return web.Response(status=429, headers={"Retry-After": "0.1"})

    app = web.Application()
    app.router.add_get("/api", rate_limited_handler)

    limiter = RateLimitMiddleware(rate=100.0, burst=10, respect_retry_after=True)
    client = await aiohttp_client(app, middlewares=(limiter,))

    began = time.monotonic()
    resp = await client.get("/api")
    took = time.monotonic() - began

    # The middleware hands back the 429 itself, after sleeping.
    assert resp.status == 429
    # The ceiling flags unexpected stalls without being flaky on CI.
    assert 0.08 <= took < 0.5
Loading