diff --git a/ably/realtime/connectionmanager.py b/ably/realtime/connectionmanager.py index e2df3074..79f89f28 100644 --- a/ably/realtime/connectionmanager.py +++ b/ably/realtime/connectionmanager.py @@ -154,6 +154,9 @@ async def __get_transport_params(self) -> dict: params["v"] = protocol_version if self.connection_details: params["resume"] = self.connection_details.connection_key + # RTN2a: Set format to msgpack if use_binary_protocol is enabled + if self.options.use_binary_protocol: + params["format"] = "msgpack" return params async def close_impl(self) -> None: diff --git a/ably/realtime/realtime_channel.py b/ably/realtime/realtime_channel.py index 51ffc8a1..7c6ce6de 100644 --- a/ably/realtime/realtime_channel.py +++ b/ably/realtime/realtime_channel.py @@ -558,7 +558,8 @@ def _on_message(self, proto_msg: dict) -> None: elif action == ProtocolMessageAction.MESSAGE: messages = [] try: - messages = Message.from_encoded_array(proto_msg.get('messages'), context=self.__decoding_context) + messages = Message.from_encoded_array(proto_msg.get('messages'), + cipher=self.cipher, context=self.__decoding_context) self.__decoding_context.last_message_id = messages[-1].id self.__channel_serial = channel_serial except AblyException as e: diff --git a/ably/transport/websockettransport.py b/ably/transport/websockettransport.py index e1b93b09..450cd364 100644 --- a/ably/transport/websockettransport.py +++ b/ably/transport/websockettransport.py @@ -8,6 +8,8 @@ from enum import IntEnum from typing import TYPE_CHECKING +import msgpack + from ably.http.httputils import HttpUtils from ably.types.connectiondetails import ConnectionDetails from ably.util.eventemitter import EventEmitter @@ -71,6 +73,7 @@ def __init__(self, connection_manager: ConnectionManager, host: str, params: dic self.is_disposed = False self.host = host self.params = params + self.format = params.get('format', 'json') super().__init__() def connect(self): @@ -189,12 +192,23 @@ async def ws_read_loop(self): raise AblyException('ws_read_loop started with no websocket', 500, 50000) try: async for raw in self.websocket: - msg = json.loads(raw) - task = asyncio.create_task(self.on_protocol_message(msg)) - task.add_done_callback(self.on_protcol_message_handled) + # Decode based on format + try: + msg = self.decode_raw_websocket_frame(raw) + task = asyncio.create_task(self.on_protocol_message(msg)) + task.add_done_callback(self.on_protcol_message_handled) + except Exception as e: + log.exception( + f"WebSocketTransport.decode(): Unexpected exception handling channel message: {e}" + ) except ConnectionClosedOK: return + def decode_raw_websocket_frame(self, raw: str | bytes) -> dict: + if self.format == 'msgpack': + return msgpack.unpackb(raw, raw=False) + return json.loads(raw) + def on_protcol_message_handled(self, task): try: exception = task.exception() @@ -231,8 +245,13 @@ async def close(self): async def send(self, message: dict): if self.websocket is None: raise Exception() - raw_msg = json.dumps(message) - log.info(f'WebSocketTransport.send(): sending {raw_msg}') + # Encode based on format + if self.format == 'msgpack': + raw_msg = msgpack.packb(message, use_bin_type=True) + log.info(f'WebSocketTransport.send(): sending msgpack message (length: {len(raw_msg)} bytes)') + else: + raw_msg = json.dumps(message) + log.info(f'WebSocketTransport.send(): sending {raw_msg}') await self.websocket.send(raw_msg) def set_idle_timer(self, timeout: float): diff --git a/test/ably/realtime/realtimechannel_publish_test.py b/test/ably/realtime/realtimechannel_publish_test.py index fb940f35..7c32c1e2 100644 --- a/test/ably/realtime/realtimechannel_publish_test.py +++ b/test/ably/realtime/realtimechannel_publish_test.py @@ -3,25 +3,28 @@ import pytest from ably.realtime.connection import ConnectionState -from ably.realtime.realtime_channel import ChannelState +from ably.realtime.realtime_channel import ChannelOptions, ChannelState from ably.transport.websockettransport import ProtocolMessageAction from ably.types.message import Message +from ably.util.crypto import CipherParams from ably.util.exceptions import AblyException, IncompatibleClientIdException from test.ably.testapp import TestApp from test.ably.utils import BaseAsyncTestCase, WaitableEvent, assert_waiter +@pytest.mark.parametrize("transport", ["json", "msgpack"], ids=["JSON", "MsgPack"]) class TestRealtimeChannelPublish(BaseAsyncTestCase): """Tests for RTN7 spec - Message acknowledgment""" @pytest.fixture(autouse=True) - async def setup(self): + async def setup(self, transport): self.test_vars = await TestApp.get_test_vars() + self.use_binary_protocol = True if transport == 'msgpack' else False # RTN7a - Basic ACK/NACK functionality async def test_publish_returns_ack_on_success(self): """RTN7a: Verify that publish awaits ACK from server""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_ack_channel') @@ -34,7 +37,7 @@ async def test_publish_returns_ack_on_success(self): async def test_publish_raises_on_nack(self): """RTN7a: Verify that publish raises exception when NACK is received""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_nack_channel') @@ -76,7 +79,7 @@ async def send_and_nack(message): # RTN7b - msgSerial incrementing async def test_msgserial_increments_sequentially(self): """RTN7b: Verify that msgSerial increments for each message""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_msgserial_channel') @@ -108,7 +111,7 @@ async def capture_serial(message): # RTN7e - Fail pending messages on SUSPENDED, CLOSED, FAILED async def test_pending_messages_fail_on_suspended(self): """RTN7e: Verify pending messages fail when connection enters SUSPENDED state""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_suspended_channel') @@ -153,7 +156,7 @@ async def check_pending(): async def test_pending_messages_fail_on_failed(self): """RTN7e: Verify pending messages fail when connection enters FAILED state""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_failed_channel') @@ -195,7 +198,7 @@ async def check_pending(): async def test_fail_on_disconnected_when_queue_messages_false(self): """RTN7d: Verify pending messages fail on DISCONNECTED if queueMessages is false""" # Create client with queueMessages=False - ably = await TestApp.get_ably_realtime(queue_messages=False) + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol, queue_messages=False) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_disconnected_channel') @@ -236,7 +239,7 @@ async def check_pending(): async def test_queue_on_disconnected_when_queue_messages_true(self): """RTN7d: Verify messages are queued (not failed) on DISCONNECTED when queueMessages is true""" # Create client with queueMessages=True (default) - ably = await TestApp.get_ably_realtime(queue_messages=True) + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol, queue_messages=True) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_queue_channel') @@ -285,7 +288,7 @@ async def check_disconnected(): # RTN19a2 - Reset msgSerial on new connectionId async def test_msgserial_resets_on_new_connection_id(self): """RTN19a2: Verify msgSerial resets to 0 when connectionId changes""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_reset_serial_channel') @@ -322,7 +325,7 @@ async def test_msgserial_resets_on_new_connection_id(self): async def test_msgserial_not_reset_on_same_connection_id(self): """RTN19a2: Verify msgSerial is NOT reset when connectionId stays the same""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_same_connection_channel') @@ -360,7 +363,7 @@ async def test_msgserial_not_reset_on_same_connection_id(self): # Test that multiple messages get correct msgSerial values async def test_multiple_messages_concurrent(self): """RTN7b: Test that multiple concurrent publishes get sequential msgSerials""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_concurrent_channel') @@ -383,7 +386,7 @@ async def test_multiple_messages_concurrent(self): # RTN19a - Resend messages awaiting ACK on reconnect async def test_pending_messages_resent_on_reconnect(self): """RTN19a: Verify messages awaiting ACK are resent when transport reconnects""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_resend_channel') @@ -437,7 +440,7 @@ async def check_pending(): async def test_msgserial_preserved_on_resume(self): """RTN19a2: Verify msgSerial counter is preserved when resuming (same connectionId)""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_preserve_serial_channel') @@ -488,7 +491,7 @@ async def check_pending(): async def test_msgserial_reset_on_failed_resume(self): """RTN19a2: Verify msgSerial counter is reset when resume fails (new connectionId)""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_reset_serial_resume_channel') @@ -540,7 +543,7 @@ async def check_pending(): # Test ACK with count > 1 async def test_ack_with_multiple_count(self): """RTN7a/RTN7b: Test that ACK with count > 1 completes multiple messages""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_multi_ack_channel') @@ -589,7 +592,7 @@ async def check_pending(): async def test_queued_messages_sent_before_channel_reattach(self): """RTL3d + RTL6c2: Verify queued messages are sent immediately on reconnection, without waiting for channel reattachment to complete""" - ably = await TestApp.get_ably_realtime(queue_messages=True) + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol, queue_messages=True) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_rtl3d_rtl6c2_channel') @@ -681,7 +684,7 @@ async def check_sent_queued_messages(): # RSL1i - Message size limit tests async def test_publish_message_exceeding_size_limit(self): """RSL1i: Verify that publishing a message exceeding the size limit raises an exception""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_size_limit_channel') @@ -702,7 +705,7 @@ async def test_publish_message_exceeding_size_limit(self): async def test_publish_message_within_size_limit(self): """RSL1i: Verify that publishing a message within the size limit succeeds""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_size_ok_channel') @@ -720,7 +723,9 @@ async def test_publish_message_within_size_limit(self): # RTL6g - Client ID validation tests async def test_publish_with_matching_client_id(self): """RTL6g2: Verify that publishing with explicit matching clientId succeeds""" - ably = await TestApp.get_ably_realtime(client_id='test_client_123') + ably = await TestApp.get_ably_realtime( + use_binary_protocol=self.use_binary_protocol, client_id='test_client_123' + ) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_client_id_channel') @@ -736,7 +741,9 @@ async def test_publish_with_matching_client_id(self): async def test_publish_with_null_client_id_when_identified(self): """RTL6g1: Verify that publishing with null clientId gets populated by server when client is identified""" - ably = await TestApp.get_ably_realtime(client_id='test_client_456') + ably = await TestApp.get_ably_realtime( + use_binary_protocol=self.use_binary_protocol, client_id='test_client_456' + ) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_null_client_id_channel') @@ -749,7 +756,9 @@ async def test_publish_with_null_client_id_when_identified(self): async def test_publish_with_mismatched_client_id_fails(self): """RTL6g3: Verify that publishing with mismatched clientId is rejected""" - ably = await TestApp.get_ably_realtime(client_id='test_client_789') + ably = await TestApp.get_ably_realtime( + use_binary_protocol=self.use_binary_protocol, client_id='test_client_789' + ) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_mismatch_client_id_channel') @@ -769,7 +778,9 @@ async def test_publish_with_mismatched_client_id_fails(self): async def test_publish_with_wildcard_client_id_fails(self): """RTL6g3: Verify that publishing with wildcard clientId is rejected""" - ably = await TestApp.get_ably_realtime(client_id='test_client_wildcard') + ably = await TestApp.get_ably_realtime( + use_binary_protocol=self.use_binary_protocol, client_id='test_client_wildcard' + ) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_wildcard_client_id_channel') @@ -790,7 +801,7 @@ async def test_publish_with_wildcard_client_id_fails(self): # RTL6i - Data type variation tests async def test_publish_with_string_data(self): """RTL6i: Verify that publishing with string data succeeds""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_string_data_channel') @@ -803,7 +814,7 @@ async def test_publish_with_string_data(self): async def test_publish_with_json_object_data(self): """RTL6i: Verify that publishing with JSON object data succeeds""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_json_object_channel') @@ -822,7 +833,7 @@ async def test_publish_with_json_object_data(self): async def test_publish_with_json_array_data(self): """RTL6i: Verify that publishing with JSON array data succeeds""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_json_array_channel') @@ -836,7 +847,7 @@ async def test_publish_with_json_array_data(self): async def test_publish_with_null_data(self): """RTL6i3: Verify that publishing with null data succeeds""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_null_data_channel') @@ -849,7 +860,7 @@ async def test_publish_with_null_data(self): async def test_publish_with_null_name(self): """RTL6i3: Verify that publishing with null name succeeds""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_null_name_channel') @@ -862,7 +873,7 @@ async def test_publish_with_null_name(self): async def test_publish_message_array(self): """RTL6i2: Verify that publishing an array of messages succeeds""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_message_array_channel') @@ -881,7 +892,7 @@ async def test_publish_message_array(self): # RTL6c4 - Channel state validation tests async def test_publish_fails_on_suspended_channel(self): """RTL6c4: Verify that publishing on a SUSPENDED channel fails""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_suspended_channel') @@ -904,7 +915,7 @@ async def test_publish_fails_on_suspended_channel(self): async def test_publish_fails_on_failed_channel(self): """RTL6c4: Verify that publishing on a FAILED channel fails""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) channel = ably.channels.get('test_failed_channel') @@ -928,10 +939,10 @@ async def test_publish_fails_on_failed_channel(self): # RSL1k - Idempotent publishing test async def test_idempotent_realtime_publishing(self): """RSL1k2, RSL1k5: Verify that messages with explicit IDs can be published for idempotent behavior""" - ably = await TestApp.get_ably_realtime() + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) - channel = ably.channels.get('test_idempotent_channel') + channel = ably.channels.get(f'test_idempotent_channel_{self.use_binary_protocol}') await channel.attach() idempotent_id = 'test-msg-id-12345' @@ -975,3 +986,36 @@ def on_message(message): assert data_received[1] == 'third message' await ably.close() + + async def test_publish_with_encryption(self): + """Verify that encrypted messages can be published and received correctly""" + # Create connection with binary protocol enabled + ably = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) + await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) + + # Get channel with encryption enabled + cipher_params = CipherParams(secret_key=b'0123456789abcdef0123456789abcdef') + channel_options = ChannelOptions(cipher=cipher_params) + channel = ably.channels.get('encrypted_channel', channel_options) + await channel.attach() + + received_data = None + data_received = WaitableEvent() + def on_message(message): + nonlocal received_data + try: + received_data = message.data + data_received.finish() + except Exception as e: + data_received.finish() + raise e + + await channel.subscribe(on_message) + + await channel.publish('encrypted_event', 'sensitive data') + + await data_received.wait() + + assert received_data == 'sensitive data' + + await ably.close() diff --git a/test/ably/realtime/realtimeconnection_test.py b/test/ably/realtime/realtimeconnection_test.py index deab3263..68ffb6dd 100644 --- a/test/ably/realtime/realtimeconnection_test.py +++ b/test/ably/realtime/realtimeconnection_test.py @@ -400,3 +400,63 @@ async def on_protocol_message(msg): await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) await ably.close() + + # RTN2f - Test msgpack format parameter when use_binary_protocol is enabled + async def test_connection_format_msgpack_with_binary_protocol(self): + """Test that format=msgpack is sent when use_binary_protocol=True""" + ably = await TestApp.get_ably_realtime(use_binary_protocol=True) + await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) + + received_raw_websocket_frames = [] + transport = ably.connection.connection_manager.transport + original_decode_raw_websocket_frame = transport.decode_raw_websocket_frame + + def intercepted_websocket_frame(data): + received_raw_websocket_frames.append(data) + return original_decode_raw_websocket_frame(data) + + transport.decode_raw_websocket_frame = intercepted_websocket_frame + + # Verify transport has format set to msgpack + assert ably.connection.connection_manager.transport is not None + assert ably.connection.connection_manager.transport.format == 'msgpack' + + # Verify params include format=msgpack + assert ably.connection.connection_manager.transport.params.get('format') == 'msgpack' + + await ably.channels.get('connection_test').publish('test', b'test') + + assert len(received_raw_websocket_frames) > 0 + assert all(isinstance(frame, bytes) for frame in received_raw_websocket_frames) + + await ably.close() + + async def test_connection_format_json_without_binary_protocol(self): + """Test that format defaults to json when use_binary_protocol=False""" + ably = await TestApp.get_ably_realtime(use_binary_protocol=False) + await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) + + received_raw_websocket_frames = [] + transport = ably.connection.connection_manager.transport + original_decode_raw_websocket_frame = transport.decode_raw_websocket_frame + + def intercepted_websocket_frame(data): + received_raw_websocket_frames.append(data) + return original_decode_raw_websocket_frame(data) + + transport.decode_raw_websocket_frame = intercepted_websocket_frame + + # Verify transport has format set to json (default) + assert ably.connection.connection_manager.transport is not None + assert ably.connection.connection_manager.transport.format == 'json' + + await ably.channels.get('connection_test').publish('test', b'test') + + # Verify params don't include format parameter (or it's json) + transport_format = ably.connection.connection_manager.transport.params.get('format') + assert transport_format is None or transport_format == 'json' + + assert len(received_raw_websocket_frames) > 0 + assert all(isinstance(frame, str) for frame in received_raw_websocket_frames) + + await ably.close()