diff --git a/faststream_redis_timers/broker.py b/faststream_redis_timers/broker.py index 66ee9da..c081e2e 100644 --- a/faststream_redis_timers/broker.py +++ b/faststream_redis_timers/broker.py @@ -19,7 +19,7 @@ from faststream.specification.schema import BrokerSpec from faststream.specification.schema.extra import Tag, TagDict -from faststream_redis_timers.configs import ConnectionState, TimersBrokerConfig +from faststream_redis_timers.configs import ConnectionState, RedisClient, TimersBrokerConfig from faststream_redis_timers.message import TimerMessage from faststream_redis_timers.publisher.producer import TimersProducer from faststream_redis_timers.publisher.usecase import TimersPublisher @@ -30,7 +30,6 @@ if typing.TYPE_CHECKING: from faststream._internal.context.repository import ContextRepo - from redis.asyncio import Redis class TimersParamsStorage(DefaultLoggerStorage): @@ -61,7 +60,7 @@ class TimersBroker( TimersRegistrator, BrokerUsecase[ TimerMessage, - "Redis[bytes]", + RedisClient, BrokerConfig, # Use BrokerConfig to avoid typing issues when passing to FastStream app ], ): @@ -80,7 +79,7 @@ class TimersBroker( def __init__( # noqa: PLR0913 self, - client: "Redis[bytes] | None" = None, + client: "RedisClient | None" = None, *, timeline_key: str = "timers_timeline", payloads_key: str = "timers_payloads", @@ -137,7 +136,7 @@ def __init__( # noqa: PLR0913 super().__init__(config=broker_config, specification=specification, routers=routers) # ty: ignore[unknown-argument] @typing.override - async def _connect(self) -> "Redis[bytes]": + async def _connect(self) -> "RedisClient": return self.config.broker_config.connection.client @typing.override @@ -214,7 +213,9 @@ async def get_pending_timers(self, topic: str, before: datetime | None = None) - """ client = self.config.broker_config.connection.client score_max: str | float = before.timestamp() if before is not None else "+inf" - raw_ids: list[bytes] = await client.zrangebyscore(self._topic_timeline_key(topic), "-inf", score_max) + raw_ids: list[bytes] | list[str] = await client.zrangebyscore( + self._topic_timeline_key(topic), "-inf", score_max + ) return [r.decode() if isinstance(r, bytes) else r for r in raw_ids] async def cancel_all(self, topic: str) -> int: diff --git a/faststream_redis_timers/configs.py b/faststream_redis_timers/configs.py index 3716af7..09eb771 100644 --- a/faststream_redis_timers/configs.py +++ b/faststream_redis_timers/configs.py @@ -9,18 +9,24 @@ from redis.asyncio import Redis +# Accepts a client created with either default (bytes) or ``decode_responses=True`` (str). +# The CLAIM Lua reply is forced through ``NEVER_DECODE`` so the binary envelope stays +# intact regardless of which mode the user picked. +type RedisClient = "Redis[bytes] | Redis[str]" + + class ConnectionState: - def __init__(self, client: "Redis[bytes] | None" = None) -> None: + def __init__(self, client: "RedisClient | None" = None) -> None: self._client = client @property - def client(self) -> "Redis[bytes]": + def client(self) -> "RedisClient": if self._client is None: msg = "Connection not available. Connect the broker first." raise IncorrectState(msg) return self._client - async def connect(self) -> "Redis[bytes]": + async def connect(self) -> "RedisClient": return self.client async def disconnect(self) -> None: diff --git a/faststream_redis_timers/message.py b/faststream_redis_timers/message.py index a57f425..e0e1015 100644 --- a/faststream_redis_timers/message.py +++ b/faststream_redis_timers/message.py @@ -7,7 +7,7 @@ if typing.TYPE_CHECKING: - from redis.asyncio import Redis + from faststream_redis_timers.configs import RedisClient class TimerMessage(TypedDict): @@ -30,7 +30,7 @@ class TimerStreamMessage(StreamMessage["TimerMessage"]): def __init__( self, *args: typing.Any, - client: "Redis[bytes] | None" = None, + client: "RedisClient | None" = None, timeline_key: str = "", payloads_key: str = "", timer_id: str = "", diff --git a/faststream_redis_timers/publisher/usecase.py b/faststream_redis_timers/publisher/usecase.py index 53fea87..f2a05a3 100644 --- a/faststream_redis_timers/publisher/usecase.py +++ b/faststream_redis_timers/publisher/usecase.py @@ -103,7 +103,7 @@ async def fetch_redis_timers(self, dt: datetime) -> list[tuple[str, str]]: """Return (topic, timer_id) pairs for timers due by *dt* on this publisher's topic.""" client = self.config._outer_config.connection.client # noqa: SLF001 timeline_key = f"{self.config._outer_config.timeline_key}:{self.config.full_topic}" # noqa: SLF001 - timer_ids: list[bytes] = await client.zrangebyscore(timeline_key, "-inf", dt.timestamp()) + timer_ids: list[bytes] | list[str] = await client.zrangebyscore(timeline_key, "-inf", dt.timestamp()) return [(self.config.topic, raw_id.decode() if isinstance(raw_id, bytes) else raw_id) for raw_id in timer_ids] async def request(self, *args: typing.Any, **kwargs: typing.Any) -> typing.NoReturn: diff --git a/faststream_redis_timers/subscriber/lua.py b/faststream_redis_timers/subscriber/lua.py index 5ce3347..8c8fc84 100644 --- a/faststream_redis_timers/subscriber/lua.py +++ b/faststream_redis_timers/subscriber/lua.py @@ -16,11 +16,12 @@ import hashlib import typing +from redis.client import NEVER_DECODE from redis.exceptions import NoScriptError if typing.TYPE_CHECKING: - from redis.asyncio import Redis + from faststream_redis_timers.configs import RedisClient CLAIM_LUA = """\ @@ -45,15 +46,23 @@ async def eval_cached( - client: "Redis[bytes]", + client: "RedisClient", script: str, sha: str, num_keys: int, *args: typing.Any, ) -> typing.Any: - """Run a script via EVALSHA, falling back to SCRIPT LOAD + EVALSHA on NOSCRIPT.""" + """ + Run a script via EVALSHA, falling back to SCRIPT LOAD + EVALSHA on NOSCRIPT. + + Uses ``NEVER_DECODE`` so the script's reply is returned as raw bytes even when the + Redis client was constructed with ``decode_responses=True``: the timer payload is + a binary envelope (BinaryMessageFormatV1) and forcing UTF-8 decoding on it would + fail at the first non-ASCII byte. + """ + options = {NEVER_DECODE: []} try: - return await client.evalsha(sha, num_keys, *args) + return await client.execute_command("EVALSHA", sha, num_keys, *args, **options) except NoScriptError: await client.script_load(script) - return await client.evalsha(sha, num_keys, *args) + return await client.execute_command("EVALSHA", sha, num_keys, *args, **options) diff --git a/faststream_redis_timers/subscriber/usecase.py b/faststream_redis_timers/subscriber/usecase.py index 14c863f..1ac7e42 100644 --- a/faststream_redis_timers/subscriber/usecase.py +++ b/faststream_redis_timers/subscriber/usecase.py @@ -26,9 +26,8 @@ from faststream._internal.endpoint.publisher import PublisherProto from faststream._internal.endpoint.subscriber.call_item import CallsCollection from faststream.message import StreamMessage - from redis.asyncio import Redis - from faststream_redis_timers.configs import TimersBrokerConfig + from faststream_redis_timers.configs import RedisClient, TimersBrokerConfig class TimersSubscriberSpecification(SubscriberSpecification["TimersBrokerConfig", TimersSubscriberSpecificationConfig]): @@ -69,7 +68,7 @@ def __init__( self._config = config @property - def _client(self) -> "Redis[bytes]": + def _client(self) -> "RedisClient": return self._outer_config.connection.client @typing.override @@ -85,7 +84,7 @@ async def start(self) -> None: else: start_signal.set() - async def _consume(self, client: "Redis[bytes]", *, start_signal: anyio.Event) -> None: + async def _consume(self, client: "RedisClient", *, start_signal: anyio.Event) -> None: with suppress(Exception): if await client.ping(): start_signal.set() @@ -122,7 +121,7 @@ async def _consume(self, client: "Redis[bytes]", *, start_signal: anyio.Event) - async def _get_msgs( self, - client: "Redis[bytes]", + client: "RedisClient", tg: "TaskGroup", limiter: anyio.CapacityLimiter, ) -> int: @@ -133,7 +132,7 @@ async def _get_msgs( return -1 now = time.time() - timer_ids: list[bytes] = await client.zrangebyscore( + timer_ids: list[bytes] | list[str] = await client.zrangebyscore( self._config.topic_timeline_key, "-inf", now, start=0, num=free ) if not timer_ids: @@ -150,7 +149,7 @@ async def _claim_and_consume( raw_id: bytes | str, lease_ttl: int, limiter: anyio.CapacityLimiter, - client: "Redis[bytes]", + client: "RedisClient", ) -> None: try: async with limiter: diff --git a/tests/test_delivery.py b/tests/test_delivery.py index 9369e87..446ea50 100644 --- a/tests/test_delivery.py +++ b/tests/test_delivery.py @@ -2,6 +2,7 @@ from datetime import timedelta from faststream.response.publish_type import PublishType +from pydantic import BaseModel from faststream_redis_timers import TimersBroker from faststream_redis_timers.response import TimerPublishCommand @@ -39,6 +40,33 @@ async def handler(body: dict) -> None: assert received == [{"order_id": 42, "status": "due"}] +async def test_subscriber_receives_pydantic_model(broker: TimersBroker) -> None: + class ChatMessage(BaseModel): + chat_id: str + message_id: int + message_text: str + + received: list[ChatMessage] = [] + event = asyncio.Event() + + @broker.subscriber("topic") + async def handler(body: ChatMessage) -> None: + received.append(body) + event.set() + + payload = { + "chat_id": "019ac5ac-0c30-7341-9667-86a1cc343f86", + "message_id": 3056, + "message_text": "Ок", + } + async with broker: + await broker.publish(payload, topic="topic") + await asyncio.wait_for(event.wait(), timeout=5.0) + + assert received == [ChatMessage(**payload)] + assert isinstance(received[0], ChatMessage) + + async def test_publisher_sends_message(broker: TimersBroker) -> None: received: list[str] = [] event = asyncio.Event() diff --git a/tests/test_envelope.py b/tests/test_envelope.py index adb5594..5746751 100644 --- a/tests/test_envelope.py +++ b/tests/test_envelope.py @@ -1,5 +1,6 @@ import asyncio import json +import os import time import uuid @@ -10,6 +11,9 @@ from faststream_redis_timers.envelope import TimerMessageFormat +REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0") + + async def test_correlation_id_propagates(broker: TimersBroker) -> None: seen: list[tuple[str, str]] = [] event = asyncio.Event() @@ -100,6 +104,41 @@ def test_envelope_size_smaller_than_legacy() -> None: assert len(new) < len(body) + 200 +async def test_works_with_decode_responses_true() -> None: + """A Redis client created with decode_responses=True must not break payload parsing.""" + client = Redis.from_url(REDIS_URL, decode_responses=True) + try: + await client.ping() + except Exception: # noqa: BLE001 # pragma: no cover + await client.aclose() # ty: ignore[unresolved-attribute] + return + + suffix = uuid.uuid4().hex + broker = TimersBroker( + client, + timeline_key=f"decstr_tl_{suffix}", + payloads_key=f"decstr_pl_{suffix}", + ) + + seen: list[dict] = [] + event = asyncio.Event() + + @broker.subscriber("topic") + async def handler(body: dict) -> None: + seen.append(body) + event.set() + + payload = {"chat_id": "abc", "message_text": "Ок", "message_id": 3056} + try: + async with broker: + await broker.publish(payload, topic="topic", timer_id="3056") + await asyncio.wait_for(event.wait(), timeout=5.0) + finally: + await client.aclose() # ty: ignore[unresolved-attribute] + + assert seen == [payload] + + async def test_legacy_envelope_still_parses(redis_client: Redis) -> None: """A v0.x JSON-of-hex payload sitting in Redis is still delivered after upgrade.""" suffix = uuid.uuid4().hex diff --git a/tests/test_unit.py b/tests/test_unit.py index 3d4f0c6..6c7826a 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -298,13 +298,13 @@ async def handler(body: str) -> None: ... async def test_eval_cached_falls_back_on_noscript() -> None: client = AsyncMock() - client.evalsha.side_effect = [NoScriptError("NOSCRIPT"), b"ok"] + client.execute_command.side_effect = [NoScriptError("NOSCRIPT"), b"ok"] client.script_load.return_value = "abc123" result = await eval_cached(client, "return 1", "abc123", 0) assert result == b"ok" - assert client.evalsha.await_count == 2 + assert client.execute_command.await_count == 2 client.script_load.assert_awaited_once_with("return 1")