diff --git a/pyproject.toml b/pyproject.toml index e29423b1..d118de3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "winter" -version = "30.0.1" +version = "31.0.0" homepage = "https://github.com/WinterFramework/winter" description = "Web Framework with focus on python typing, dataclasses and modular design" authors = ["Alexander Egorov "] @@ -41,6 +41,7 @@ pydantic = ">=1.10, <2" openapi-spec-validator = ">=0.5.7, <1" uritemplate = "==4.2.0" # Lib doesn't follow semantic versioning httpx = ">=0.24.1, <0.28" +redis = "^6.2.0" [tool.poetry.dev-dependencies] flake8 = ">=3.7.7, <4" @@ -61,6 +62,7 @@ pytz = ">=2020.5" [tool.poetry.group.dev.dependencies] setuptools = "^71.1.0" +testcontainers = "^4.10.0" [build-system] requires = ["poetry-core>=1.3.1"] diff --git a/tests/apps.py b/tests/apps.py index 7c205b22..9539f99c 100644 --- a/tests/apps.py +++ b/tests/apps.py @@ -1,6 +1,10 @@ +import atexit + from django.apps import AppConfig +from testcontainers.redis import RedisContainer from tests.web.interceptors import HelloWorldInterceptor +from winter.web import RedisThrottlingConfiguration from winter.web import exception_handlers_registry from winter.web import interceptor_registry from winter.web.exceptions.handlers import DefaultExceptionHandler @@ -9,6 +13,10 @@ class TestAppConfig(AppConfig): name = 'tests' + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._redis_container: RedisContainer | None = None + def ready(self): # define this import for force initialization all modules and to register Exceptions from .urls import urlpatterns # noqa: F401 @@ -19,7 +27,26 @@ def ready(self): interceptor_registry.add_interceptor(HelloWorldInterceptor()) winter_openapi.setup() + winter.web.setup() + + self._redis_container = RedisContainer() + self._redis_container.start() + self._redis_container.get_client().flushdb() + atexit.register(self.cleanup_redis) + + redis_throttling_configuration = RedisThrottlingConfiguration( + host=self._redis_container.get_container_host_ip(), + port=self._redis_container.get_exposed_port(self._redis_container.port), + db=0, + password=self._redis_container.password + ) + winter.web.set_redis_throttling_configuration(redis_throttling_configuration) + winter_django.setup() exception_handlers_registry.set_default_handler(DefaultExceptionHandler) # for 100% test coverage + + def cleanup_redis(self): # pragma: no cover + if self._redis_container: + self._redis_container.stop() diff --git a/tests/test_throttling.py b/tests/web/test_throttling.py similarity index 62% rename from tests/test_throttling.py rename to tests/web/test_throttling.py index c448c1a9..ca3c5cb8 100644 --- a/tests/test_throttling.py +++ b/tests/web/test_throttling.py @@ -3,7 +3,14 @@ import freezegun import pytest +from mock import patch +from winter.web import RedisThrottlingConfiguration +from winter.web import ThrottlingMisconfigurationException +from winter.web import set_redis_throttling_configuration +from winter.web.throttling.redis_throttling_client import get_redis_throttling_client +from winter.web.throttling import redis_throttling_client +from winter.web.throttling import redis_throttling_configuration expected_error_response = { 'status': 429, @@ -65,3 +72,35 @@ def test_get_throttling_with_conditional_reset(api_client): is_reset = True if i == 5 else False response = api_client.get(f'/with-throttling/with-reset/?is_reset={is_reset}') assert response.status_code == HTTPStatus.OK, i + + +@patch.object(redis_throttling_client, 'get_redis_throttling_configuration', return_value=None) +@patch.object(redis_throttling_client, '_redis_throttling_client', None) +def test_get_redis_throttling_client_without_configuration(_): + with pytest.raises(ThrottlingMisconfigurationException) as exc_info: + get_redis_throttling_client() + + assert 'Configuration for Redis must be set' in str(exc_info.value) + + +@patch.object( + redis_throttling_configuration, + '_redis_throttling_configuration', + RedisThrottlingConfiguration( + host='localhost', + port=1234, + db=0, + password=None + ) +) +def test_try_to_set_redis_configuration_twice(): + configuration = RedisThrottlingConfiguration( + host='localhost', + port=5678, + db=0, + password=None + ) + with pytest.raises(ThrottlingMisconfigurationException) as exc_info: + set_redis_throttling_configuration(configuration) + + assert 'RedisThrottlingConfiguration is already initialized' in str(exc_info.value) diff --git a/winter/web/__init__.py b/winter/web/__init__.py index 4372de0f..e123c493 100644 --- a/winter/web/__init__.py +++ b/winter/web/__init__.py @@ -20,6 +20,9 @@ from .response_header_resolver import ResponseHeaderArgumentResolver from .response_header_serializer import response_headers_serializer from .response_status_annotation import response_status +from .throttling import ThrottlingMisconfigurationException +from .throttling import RedisThrottlingConfiguration +from .throttling import set_redis_throttling_configuration from .throttling import throttling from .urls import register_url_regexp diff --git a/winter/web/throttling/__init__.py b/winter/web/throttling/__init__.py new file mode 100644 index 00000000..6677a8d8 --- /dev/null +++ b/winter/web/throttling/__init__.py @@ -0,0 +1,6 @@ +from .exceptions import ThrottlingMisconfigurationException +from .throttling import throttling +from .throttling import reset +from .throttling import create_throttle_class +from .redis_throttling_configuration import set_redis_throttling_configuration +from .redis_throttling_configuration import RedisThrottlingConfiguration diff --git a/winter/web/throttling/exceptions.py b/winter/web/throttling/exceptions.py new file mode 100644 index 00000000..66e83489 --- /dev/null +++ b/winter/web/throttling/exceptions.py @@ -0,0 +1,2 @@ +class ThrottlingMisconfigurationException(Exception): + pass diff --git a/winter/web/throttling/redis_throttling_client.py b/winter/web/throttling/redis_throttling_client.py new file mode 100644 index 00000000..a99122ea --- /dev/null +++ b/winter/web/throttling/redis_throttling_client.py @@ -0,0 +1,67 @@ +import time + +from redis import Redis + +from .exceptions import ThrottlingMisconfigurationException +from .redis_throttling_configuration import get_redis_throttling_configuration +from .redis_throttling_configuration import RedisThrottlingConfiguration + + +class RedisThrottlingClient: + # Redis Lua scripts are atomic + # Sliding window throttling. + # Rejected requests aren't counted. + THROTTLING_LUA = ''' + local key = KEYS[1] + local now = tonumber(ARGV[1]) + local duration = tonumber(ARGV[2]) + local max_requests = tonumber(ARGV[3]) + + redis.call("ZREMRANGEBYSCORE", key, 0, now - duration) + local count = redis.call("ZCARD", key) + + if count >= max_requests then + return 0 + end + + redis.call("ZADD", key, now, now) + redis.call("EXPIRE", key, duration) + return 1 + ''' + + def __init__(self, configuration: RedisThrottlingConfiguration): + self._redis_client = Redis( + host=configuration.host, + port=configuration.port, + db=configuration.db, + password=configuration.password, + decode_responses=True, + ) + self._throttling_script = self._redis_client.register_script(self.THROTTLING_LUA) + + def is_request_allowed(self, key: str, duration: int, num_requests: int) -> bool: + now = time.time() + is_allowed = self._throttling_script( + keys=[key], + args=[now, duration, num_requests] + ) + return is_allowed == 1 + + def delete(self, key: str): + self._redis_client.delete(key) + + +_redis_throttling_client: RedisThrottlingClient | None = None + +def get_redis_throttling_client() -> RedisThrottlingClient: + global _redis_throttling_client + + if _redis_throttling_client is None: + configuration = get_redis_throttling_configuration() + + if configuration is None: + raise ThrottlingMisconfigurationException('Configuration for Redis must be set before using the throttling') + + _redis_throttling_client = RedisThrottlingClient(configuration) + + return _redis_throttling_client diff --git a/winter/web/throttling/redis_throttling_configuration.py b/winter/web/throttling/redis_throttling_configuration.py new file mode 100644 index 00000000..410c3d63 --- /dev/null +++ b/winter/web/throttling/redis_throttling_configuration.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass + +from .exceptions import ThrottlingMisconfigurationException + + +@dataclass +class RedisThrottlingConfiguration: + host: str + port: int + db: int + password: str | None = None + + +_redis_throttling_configuration: RedisThrottlingConfiguration | None = None + + +def set_redis_throttling_configuration(configuration: RedisThrottlingConfiguration): + global _redis_throttling_configuration + if _redis_throttling_configuration is not None: + raise ThrottlingMisconfigurationException(f'{RedisThrottlingConfiguration.__name__} is already initialized') + _redis_throttling_configuration = configuration + + +def get_redis_throttling_configuration() -> RedisThrottlingConfiguration | None: + return _redis_throttling_configuration diff --git a/winter/web/throttling.py b/winter/web/throttling/throttling.py similarity index 82% rename from winter/web/throttling.py rename to winter/web/throttling/throttling.py index caf051e3..0b0fd6d7 100644 --- a/winter/web/throttling.py +++ b/winter/web/throttling/throttling.py @@ -5,12 +5,12 @@ from typing import Tuple import django.http -from django.core.cache import cache as default_cache from winter.core import annotate_method +from .redis_throttling_client import get_redis_throttling_client if TYPE_CHECKING: - from .routing import Route # noqa: F401 + from winter.web.routing import Route # noqa: F401 @dataclasses.dataclass @@ -33,23 +33,13 @@ def throttling(rate: Optional[str], scope: Optional[str] = None): class BaseRateThrottle: def __init__(self, throttling_: Throttling): self._throttling = throttling_ + self._redis_client = get_redis_throttling_client() def allow_request(self, request: django.http.HttpRequest) -> bool: ident = _get_ident(request) key = _get_cache_key(self._throttling.scope, ident) - history = default_cache.get(key, []) - now = time.time() - - while history and history[-1] <= now - self._throttling.duration: - history.pop() - - if len(history) >= self._throttling.num_requests: - return False - - history.insert(0, now) - default_cache.set(key, history, self._throttling.duration) - return True + return self._redis_client.is_request_allowed(key, self._throttling.duration, self._throttling.num_requests) def reset(request: django.http.HttpRequest, scope: str): @@ -59,7 +49,8 @@ def reset(request: django.http.HttpRequest, scope: str): """ ident = _get_ident(request) key = _get_cache_key(scope, ident) - default_cache.delete(key) + redis_client = get_redis_throttling_client() + redis_client.delete(key) CACHE_KEY_FORMAT = 'throttle_{scope}_{ident}'