Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions faststream_redis_timers/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from faststream.specification.schema import BrokerSpec
from faststream.specification.schema.extra import Tag, TagDict

from faststream_redis_timers.configs import ConnectionState, TimersBrokerConfig
from faststream_redis_timers.configs import ConnectionState, RedisClient, TimersBrokerConfig
from faststream_redis_timers.message import TimerMessage
from faststream_redis_timers.publisher.producer import TimersProducer
from faststream_redis_timers.publisher.usecase import TimersPublisher
Expand All @@ -30,7 +30,6 @@

if typing.TYPE_CHECKING:
from faststream._internal.context.repository import ContextRepo
from redis.asyncio import Redis


class TimersParamsStorage(DefaultLoggerStorage):
Expand Down Expand Up @@ -61,7 +60,7 @@ class TimersBroker(
TimersRegistrator,
BrokerUsecase[
TimerMessage,
"Redis[bytes]",
RedisClient,
BrokerConfig, # Use BrokerConfig to avoid typing issues when passing to FastStream app
],
):
Expand All @@ -80,7 +79,7 @@ class TimersBroker(

def __init__( # noqa: PLR0913
self,
client: "Redis[bytes] | None" = None,
client: "RedisClient | None" = None,
*,
timeline_key: str = "timers_timeline",
payloads_key: str = "timers_payloads",
Expand Down Expand Up @@ -137,7 +136,7 @@ def __init__( # noqa: PLR0913
super().__init__(config=broker_config, specification=specification, routers=routers) # ty: ignore[unknown-argument]

@typing.override
async def _connect(self) -> "Redis[bytes]":
async def _connect(self) -> "RedisClient":
return self.config.broker_config.connection.client

@typing.override
Expand Down Expand Up @@ -214,7 +213,9 @@ async def get_pending_timers(self, topic: str, before: datetime | None = None) -
"""
client = self.config.broker_config.connection.client
score_max: str | float = before.timestamp() if before is not None else "+inf"
raw_ids: list[bytes] = await client.zrangebyscore(self._topic_timeline_key(topic), "-inf", score_max)
raw_ids: list[bytes] | list[str] = await client.zrangebyscore(
self._topic_timeline_key(topic), "-inf", score_max
)
return [r.decode() if isinstance(r, bytes) else r for r in raw_ids]

async def cancel_all(self, topic: str) -> int:
Expand Down
12 changes: 9 additions & 3 deletions faststream_redis_timers/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,24 @@
from redis.asyncio import Redis


# Accepts a client created with either default (bytes) or ``decode_responses=True`` (str).
# The CLAIM Lua reply is forced through ``NEVER_DECODE`` so the binary envelope stays
# intact regardless of which mode the user picked.
type RedisClient = "Redis[bytes] | Redis[str]"


class ConnectionState:
def __init__(self, client: "Redis[bytes] | None" = None) -> None:
def __init__(self, client: "RedisClient | None" = None) -> None:
self._client = client

@property
def client(self) -> "Redis[bytes]":
def client(self) -> "RedisClient":
if self._client is None:
msg = "Connection not available. Connect the broker first."
raise IncorrectState(msg)
return self._client

async def connect(self) -> "Redis[bytes]":
async def connect(self) -> "RedisClient":
return self.client

async def disconnect(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions faststream_redis_timers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


if typing.TYPE_CHECKING:
from redis.asyncio import Redis
from faststream_redis_timers.configs import RedisClient


class TimerMessage(TypedDict):
Expand All @@ -30,7 +30,7 @@ class TimerStreamMessage(StreamMessage["TimerMessage"]):
def __init__(
self,
*args: typing.Any,
client: "Redis[bytes] | None" = None,
client: "RedisClient | None" = None,
timeline_key: str = "",
payloads_key: str = "",
timer_id: str = "",
Expand Down
2 changes: 1 addition & 1 deletion faststream_redis_timers/publisher/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def fetch_redis_timers(self, dt: datetime) -> list[tuple[str, str]]:
"""Return (topic, timer_id) pairs for timers due by *dt* on this publisher's topic."""
client = self.config._outer_config.connection.client # noqa: SLF001
timeline_key = f"{self.config._outer_config.timeline_key}:{self.config.full_topic}" # noqa: SLF001
timer_ids: list[bytes] = await client.zrangebyscore(timeline_key, "-inf", dt.timestamp())
timer_ids: list[bytes] | list[str] = await client.zrangebyscore(timeline_key, "-inf", dt.timestamp())
return [(self.config.topic, raw_id.decode() if isinstance(raw_id, bytes) else raw_id) for raw_id in timer_ids]

async def request(self, *args: typing.Any, **kwargs: typing.Any) -> typing.NoReturn:
Expand Down
19 changes: 14 additions & 5 deletions faststream_redis_timers/subscriber/lua.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
import hashlib
import typing

from redis.client import NEVER_DECODE
from redis.exceptions import NoScriptError


if typing.TYPE_CHECKING:
from redis.asyncio import Redis
from faststream_redis_timers.configs import RedisClient


CLAIM_LUA = """\
Expand All @@ -45,15 +46,23 @@


async def eval_cached(
client: "Redis[bytes]",
client: "RedisClient",
script: str,
sha: str,
num_keys: int,
*args: typing.Any,
) -> typing.Any:
"""Run a script via EVALSHA, falling back to SCRIPT LOAD + EVALSHA on NOSCRIPT."""
"""
Run a script via EVALSHA, falling back to SCRIPT LOAD + EVALSHA on NOSCRIPT.

Uses ``NEVER_DECODE`` so the script's reply is returned as raw bytes even when the
Redis client was constructed with ``decode_responses=True``: the timer payload is
a binary envelope (BinaryMessageFormatV1) and forcing UTF-8 decoding on it would
fail at the first non-ASCII byte.
"""
options = {NEVER_DECODE: []}
try:
return await client.evalsha(sha, num_keys, *args)
return await client.execute_command("EVALSHA", sha, num_keys, *args, **options)
except NoScriptError:
await client.script_load(script)
return await client.evalsha(sha, num_keys, *args)
return await client.execute_command("EVALSHA", sha, num_keys, *args, **options)
13 changes: 6 additions & 7 deletions faststream_redis_timers/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@
from faststream._internal.endpoint.publisher import PublisherProto
from faststream._internal.endpoint.subscriber.call_item import CallsCollection
from faststream.message import StreamMessage
from redis.asyncio import Redis

from faststream_redis_timers.configs import TimersBrokerConfig
from faststream_redis_timers.configs import RedisClient, TimersBrokerConfig


class TimersSubscriberSpecification(SubscriberSpecification["TimersBrokerConfig", TimersSubscriberSpecificationConfig]):
Expand Down Expand Up @@ -69,7 +68,7 @@ def __init__(
self._config = config

@property
def _client(self) -> "Redis[bytes]":
def _client(self) -> "RedisClient":
return self._outer_config.connection.client

@typing.override
Expand All @@ -85,7 +84,7 @@ async def start(self) -> None:
else:
start_signal.set()

async def _consume(self, client: "Redis[bytes]", *, start_signal: anyio.Event) -> None:
async def _consume(self, client: "RedisClient", *, start_signal: anyio.Event) -> None:
with suppress(Exception):
if await client.ping():
start_signal.set()
Expand Down Expand Up @@ -122,7 +121,7 @@ async def _consume(self, client: "Redis[bytes]", *, start_signal: anyio.Event) -

async def _get_msgs(
self,
client: "Redis[bytes]",
client: "RedisClient",
tg: "TaskGroup",
limiter: anyio.CapacityLimiter,
) -> int:
Expand All @@ -133,7 +132,7 @@ async def _get_msgs(
return -1

now = time.time()
timer_ids: list[bytes] = await client.zrangebyscore(
timer_ids: list[bytes] | list[str] = await client.zrangebyscore(
self._config.topic_timeline_key, "-inf", now, start=0, num=free
)
if not timer_ids:
Expand All @@ -150,7 +149,7 @@ async def _claim_and_consume(
raw_id: bytes | str,
lease_ttl: int,
limiter: anyio.CapacityLimiter,
client: "Redis[bytes]",
client: "RedisClient",
) -> None:
try:
async with limiter:
Expand Down
28 changes: 28 additions & 0 deletions tests/test_delivery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from datetime import timedelta

from faststream.response.publish_type import PublishType
from pydantic import BaseModel

from faststream_redis_timers import TimersBroker
from faststream_redis_timers.response import TimerPublishCommand
Expand Down Expand Up @@ -39,6 +40,33 @@ async def handler(body: dict) -> None:
assert received == [{"order_id": 42, "status": "due"}]


async def test_subscriber_receives_pydantic_model(broker: TimersBroker) -> None:
class ChatMessage(BaseModel):
chat_id: str
message_id: int
message_text: str

received: list[ChatMessage] = []
event = asyncio.Event()

@broker.subscriber("topic")
async def handler(body: ChatMessage) -> None:
received.append(body)
event.set()

payload = {
"chat_id": "019ac5ac-0c30-7341-9667-86a1cc343f86",
"message_id": 3056,
"message_text": "Ок",
}
async with broker:
await broker.publish(payload, topic="topic")
await asyncio.wait_for(event.wait(), timeout=5.0)

assert received == [ChatMessage(**payload)]
assert isinstance(received[0], ChatMessage)


async def test_publisher_sends_message(broker: TimersBroker) -> None:
received: list[str] = []
event = asyncio.Event()
Expand Down
39 changes: 39 additions & 0 deletions tests/test_envelope.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import json
import os
import time
import uuid

Expand All @@ -10,6 +11,9 @@
from faststream_redis_timers.envelope import TimerMessageFormat


REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")


async def test_correlation_id_propagates(broker: TimersBroker) -> None:
seen: list[tuple[str, str]] = []
event = asyncio.Event()
Expand Down Expand Up @@ -100,6 +104,41 @@ def test_envelope_size_smaller_than_legacy() -> None:
assert len(new) < len(body) + 200


async def test_works_with_decode_responses_true() -> None:
"""A Redis client created with decode_responses=True must not break payload parsing."""
client = Redis.from_url(REDIS_URL, decode_responses=True)
try:
await client.ping()
except Exception: # noqa: BLE001 # pragma: no cover
await client.aclose() # ty: ignore[unresolved-attribute]
return

suffix = uuid.uuid4().hex
broker = TimersBroker(
client,
timeline_key=f"decstr_tl_{suffix}",
payloads_key=f"decstr_pl_{suffix}",
)

seen: list[dict] = []
event = asyncio.Event()

@broker.subscriber("topic")
async def handler(body: dict) -> None:
seen.append(body)
event.set()

payload = {"chat_id": "abc", "message_text": "Ок", "message_id": 3056}
try:
async with broker:
await broker.publish(payload, topic="topic", timer_id="3056")
await asyncio.wait_for(event.wait(), timeout=5.0)
finally:
await client.aclose() # ty: ignore[unresolved-attribute]

assert seen == [payload]


async def test_legacy_envelope_still_parses(redis_client: Redis) -> None:
"""A v0.x JSON-of-hex payload sitting in Redis is still delivered after upgrade."""
suffix = uuid.uuid4().hex
Expand Down
4 changes: 2 additions & 2 deletions tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,13 @@ async def handler(body: str) -> None: ...

async def test_eval_cached_falls_back_on_noscript() -> None:
client = AsyncMock()
client.evalsha.side_effect = [NoScriptError("NOSCRIPT"), b"ok"]
client.execute_command.side_effect = [NoScriptError("NOSCRIPT"), b"ok"]
client.script_load.return_value = "abc123"

result = await eval_cached(client, "return 1", "abc123", 0)

assert result == b"ok"
assert client.evalsha.await_count == 2
assert client.execute_command.await_count == 2
client.script_load.assert_awaited_once_with("return 1")


Expand Down
Loading