From 56b8d932f59c08335e78513c200d3c8a5097ae12 Mon Sep 17 00:00:00 2001 From: zhixiangli Date: Thu, 2 Apr 2026 09:45:27 +0000 Subject: [PATCH 01/18] feat: add _StreamMultiplexer for asyncio bidi-gRPC streams --- .../storage/asyncio/_stream_multiplexer.py | 198 +++++++ .../unit/asyncio/test_stream_multiplexer.py | 503 ++++++++++++++++++ 2 files changed, 701 insertions(+) create mode 100644 packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py create mode 100644 packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py diff --git a/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py new file mode 100644 index 000000000000..abd99e087cd3 --- /dev/null +++ b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py @@ -0,0 +1,198 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import annotations

import asyncio
import logging
from typing import (
    TYPE_CHECKING,
    Awaitable,
    Callable,
    Dict,
    Iterable,
    Optional,
    Set,
)

if TYPE_CHECKING:
    # These names are referenced in annotations only, and annotations are not
    # evaluated at runtime because of ``from __future__ import annotations``.
    from google.cloud import _storage_v2
    from google.cloud.storage.asyncio.async_read_object_stream import (
        _AsyncReadObjectStream,
    )

logger = logging.getLogger(__name__)

_DEFAULT_QUEUE_MAX_SIZE = 100
_DEFAULT_PUT_TIMEOUT = 20.0


class _StreamError:
    """Wraps an error with the stream generation that produced it."""

    def __init__(self, exception: Exception, generation: int):
        self.exception = exception
        self.generation = generation

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(exception={self.exception!r}, "
            f"generation={self.generation!r})"
        )


class _StreamEnd:
    """Signals the stream closed normally."""


class _StreamMultiplexer:
    """Multiplexes concurrent download tasks over a single bidi-gRPC stream.

    Routes responses from a background recv loop to per-task asyncio.Queues
    keyed by read_id. Coordinates stream reopening via generation-gated
    locking.

    A slow consumer on one task will slow down the entire shared connection
    due to bounded queue backpressure propagating through gRPC flow control.

    :param stream: The open stream that requests are sent on and responses
        are received from.
    :param queue_max_size: ``maxsize`` of every queue handed out by
        :meth:`register`.
    :param put_timeout: Seconds the recv loop waits for room in a full queue
        before it drops the item it was trying to deliver.
    """

    def __init__(
        self,
        stream: _AsyncReadObjectStream,
        queue_max_size: int = _DEFAULT_QUEUE_MAX_SIZE,
        put_timeout: float = _DEFAULT_PUT_TIMEOUT,
    ):
        self._stream = stream
        self._stream_generation: int = 0
        self._queues: Dict[int, asyncio.Queue] = {}
        self._reopen_lock = asyncio.Lock()
        self._recv_task: Optional[asyncio.Task] = None
        self._queue_max_size = queue_max_size
        self._put_timeout = put_timeout

    @property
    def stream_generation(self) -> int:
        """Counter that is incremented each time the stream is replaced."""
        return self._stream_generation

    def register(self, read_ids: Set[int]) -> asyncio.Queue:
        """Register read_ids for a task and return its response queue.

        Every id in ``read_ids`` is routed to the same, newly created queue.
        """
        queue = asyncio.Queue(maxsize=self._queue_max_size)
        for read_id in read_ids:
            self._queues[read_id] = queue
        return queue

    def unregister(self, read_ids: Set[int]) -> None:
        """Remove read_ids from routing; unknown ids are ignored."""
        for read_id in read_ids:
            self._queues.pop(read_id, None)

    def _get_unique_queues(self) -> Set[asyncio.Queue]:
        """Return the distinct queues currently registered."""
        return set(self._queues.values())

    async def _put_with_timeout(self, queue: asyncio.Queue, item) -> None:
        """Put ``item`` on ``queue``, dropping it if no room appears in time."""
        try:
            await asyncio.wait_for(queue.put(item), timeout=self._put_timeout)
        except asyncio.TimeoutError:
            if queue not in self._get_unique_queues():
                logger.debug("Dropped item for unregistered queue.")
            else:
                logger.warning(
                    "Queue full for too long. Dropping item to prevent multiplexer hang."
                )

    async def _fan_out(self, queues: Iterable[asyncio.Queue], item) -> None:
        """Deliver ``item`` to every queue in ``queues`` concurrently."""
        await asyncio.gather(
            *(self._put_with_timeout(queue, item) for queue in queues)
        )

    def _ensure_recv_loop(self) -> None:
        """Start the background recv loop unless one is already running."""
        if self._recv_task is None or self._recv_task.done():
            self._recv_task = asyncio.create_task(self._recv_loop())

    def _stop_recv_loop(self) -> None:
        """Request cancellation of the recv loop without waiting for it."""
        if self._recv_task and not self._recv_task.done():
            self._recv_task.cancel()

    async def _cancel_recv_loop(self) -> None:
        """Cancel the recv loop, if any, and wait until it has stopped.

        ``asyncio.wait`` is used instead of awaiting the task directly so the
        cancellation of the recv loop is not re-raised here, while a
        cancellation of the *calling* task still propagates to the caller.
        """
        task = self._recv_task
        if task is None:
            return
        self._stop_recv_loop()
        if not task.done():
            await asyncio.wait({task})
        if not task.cancelled():
            # Retrieve the result so asyncio does not report it as unhandled.
            failure = task.exception()
            if failure is not None:
                logger.debug("Recv loop ended with an error.", exc_info=failure)

    def _put_error_nowait(self, queue: asyncio.Queue, error: _StreamError) -> None:
        """Put ``error`` on ``queue`` without blocking.

        When the queue is full the oldest items are discarded until the error
        fits, so an error is always delivered.
        """
        while True:
            try:
                queue.put_nowait(error)
                break
            except asyncio.QueueFull:
                try:
                    queue.get_nowait()
                except asyncio.QueueEmpty:
                    pass

    def _broadcast_error(self, error: _StreamError) -> None:
        """Deliver ``error`` to every registered queue without blocking."""
        for queue in self._get_unique_queues():
            self._put_error_nowait(queue, error)

    async def _recv_loop(self) -> None:
        """Read responses from the stream and route them until it ends.

        * ``None`` from the stream is reported to every queue as one shared
          :class:`_StreamEnd` and ends the loop.
        * A response carrying data ranges goes once to each queue registered
          for any of its read_ids; ranges for unknown read_ids are ignored.
        * Any other response is broadcast to every queue.
        * An exception from the stream is broadcast as a :class:`_StreamError`
          tagged with the current generation, and ends the loop.
        """
        try:
            while True:
                response = await self._stream.recv()
                if response is None:
                    await self._fan_out(self._get_unique_queues(), _StreamEnd())
                    return

                if response.object_data_ranges:
                    targets: Set[asyncio.Queue] = set()
                    for data_range in response.object_data_ranges:
                        queue = self._queues.get(data_range.read_range.read_id)
                        if queue is not None:
                            targets.add(queue)
                else:
                    targets = self._get_unique_queues()
                await self._fan_out(targets, response)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            self._broadcast_error(_StreamError(e, self._stream_generation))

    async def send(self, request: _storage_v2.BidiReadObjectRequest) -> int:
        """Send ``request`` on the current stream.

        The recv loop is started first if it is not running. Returns the
        stream generation read after the send completed.
        """
        self._ensure_recv_loop()
        await self._stream.send(request)
        return self._stream_generation

    async def reopen_stream(
        self,
        broken_generation: int,
        stream_factory: Callable[[], Awaitable[_AsyncReadObjectStream]],
    ) -> None:
        """Replace the stream if ``broken_generation`` is still current.

        Callers that observed the same broken generation race for the lock;
        the first one reopens, the others return without doing anything
        because the generation has moved on by the time they hold the lock.

        If ``stream_factory`` raises, the exception propagates, the generation
        is left unchanged and no recv loop is started, so the same call can be
        retried.
        """
        async with self._reopen_lock:
            if self._stream_generation != broken_generation:
                return
            await self._cancel_recv_loop()
            self._broadcast_error(
                _StreamError(Exception("Stream reopening"), self._stream_generation)
            )
            try:
                await self._stream.close()
            except Exception:
                # Best effort: the stream is being discarded anyway.
                logger.debug(
                    "Ignoring error while closing the old stream.", exc_info=True
                )
            self._stream = await stream_factory()
            self._stream_generation += 1
            self._ensure_recv_loop()

    async def close(self) -> None:
        """Stop the recv loop and notify every registered queue.

        The underlying stream is not closed here.
        """
        await self._cancel_recv_loop()
        self._broadcast_error(
            _StreamError(Exception("Multiplexer closed"), self._stream_generation)
        )
import asyncio
from unittest.mock import AsyncMock

import pytest

from google.cloud import _storage_v2
from google.cloud.storage.asyncio._stream_multiplexer import (
    _DEFAULT_QUEUE_MAX_SIZE,
    _StreamEnd,
    _StreamError,
    _StreamMultiplexer,
)


def _make_blocking_stream():
    """Return a mock stream whose ``recv`` never completes on its own."""
    stream = AsyncMock()
    stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
    return stream


def _make_response(read_id, data=b"data", range_end=False):
    """Build a response carrying a single data range for ``read_id``."""
    return _storage_v2.BidiReadObjectResponse(
        object_data_ranges=[
            _storage_v2.ObjectRangeData(
                checksummed_data=_storage_v2.ChecksummedData(content=data),
                read_range=_storage_v2.ReadRange(
                    read_id=read_id, read_offset=0, read_length=len(data)
                ),
                range_end=range_end,
            )
        ]
    )


def _make_multi_range_response(read_ids, data=b"data"):
    """Build a response carrying one data range per id in ``read_ids``."""
    ranges = [
        _storage_v2.ObjectRangeData(
            checksummed_data=_storage_v2.ChecksummedData(content=data),
            read_range=_storage_v2.ReadRange(
                read_id=rid, read_offset=0, read_length=len(data)
            ),
        )
        for rid in read_ids
    ]
    return _storage_v2.BidiReadObjectResponse(object_data_ranges=ranges)


class TestSentinelTypes:
    def test_stream_error_stores_exception_and_generation(self):
        exc = ValueError("test")
        error = _StreamError(exc, generation=3)
        assert error.exception is exc
        assert error.generation == 3

    def test_stream_end_is_instantiable(self):
        sentinel = _StreamEnd()
        assert isinstance(sentinel, _StreamEnd)


class TestStreamMultiplexerInit:
    def test_init_sets_stream_and_defaults(self):
        mock_stream = AsyncMock()
        mux = _StreamMultiplexer(mock_stream)
        assert mux._stream is mock_stream
        assert mux.stream_generation == 0
        assert mux._queues == {}
        assert mux._recv_task is None
        assert mux._queue_max_size == _DEFAULT_QUEUE_MAX_SIZE

    def test_init_custom_queue_size(self):
        mock_stream = AsyncMock()
        mux = _StreamMultiplexer(mock_stream, queue_max_size=50)
        assert mux._queue_max_size == 50


class TestRegisterUnregister:
    def _make_multiplexer(self):
        mock_stream = _make_blocking_stream()
        return _StreamMultiplexer(mock_stream), mock_stream

    @pytest.mark.asyncio
    async def test_register_returns_bounded_queue(self):
        mux, _ = self._make_multiplexer()
        queue = mux.register({1, 2, 3})
        assert isinstance(queue, asyncio.Queue)
        assert queue.maxsize == _DEFAULT_QUEUE_MAX_SIZE
        mux.unregister({1, 2, 3})

    @pytest.mark.asyncio
    async def test_register_maps_read_ids_to_same_queue(self):
        mux, _ = self._make_multiplexer()
        queue = mux.register({10, 20})
        assert mux._queues[10] is queue
        assert mux._queues[20] is queue
        mux.unregister({10, 20})

    @pytest.mark.asyncio
    async def test_register_does_not_start_recv_loop(self):
        mux, _ = self._make_multiplexer()
        assert mux._recv_task is None
        mux.register({1})
        assert mux._recv_task is None
        mux.unregister({1})

    @pytest.mark.asyncio
    async def test_two_registers_get_separate_queues(self):
        mux, _ = self._make_multiplexer()
        q1 = mux.register({1})
        q2 = mux.register({2})
        assert q1 is not q2
        assert mux._queues[1] is q1
        assert mux._queues[2] is q2
        mux.unregister({1, 2})

    @pytest.mark.asyncio
    async def test_unregister_removes_read_ids(self):
        mux, _ = self._make_multiplexer()
        mux.register({1, 2})
        mux.unregister({1})
        assert 1 not in mux._queues
        assert 2 in mux._queues
        mux.unregister({2})

    @pytest.mark.asyncio
    async def test_unregister_all_does_not_stop_recv_loop(self):
        mux, _ = self._make_multiplexer()
        mux.register({1})
        mux._ensure_recv_loop()
        recv_task = mux._recv_task
        assert recv_task is not None
        mux.unregister({1})
        await asyncio.sleep(0)
        assert not recv_task.cancelled()
        # Do not leave the blocked recv loop running after the test.
        await mux.close()

    @pytest.mark.asyncio
    async def test_unregister_nonexistent_is_noop(self):
        mux, _ = self._make_multiplexer()
        mux.register({1})
        mux.unregister({999})
        assert 1 in mux._queues
        mux.unregister({1})


class TestRecvLoop:
    @pytest.mark.asyncio
    async def test_routes_response_by_read_id(self):
        mock_stream = AsyncMock()
        resp1 = _make_response(read_id=10, data=b"hello")
        resp2 = _make_response(read_id=20, data=b"world")
        mock_stream.recv = AsyncMock(side_effect=[resp1, resp2, None])

        mux = _StreamMultiplexer(mock_stream)
        q1 = mux.register({10})
        q2 = mux.register({20})
        mux._ensure_recv_loop()

        item1 = await asyncio.wait_for(q1.get(), timeout=1)
        item2 = await asyncio.wait_for(q2.get(), timeout=1)

        assert item1 is resp1
        assert item2 is resp2
        end1 = await asyncio.wait_for(q1.get(), timeout=1)
        end2 = await asyncio.wait_for(q2.get(), timeout=1)
        assert isinstance(end1, _StreamEnd)
        assert isinstance(end2, _StreamEnd)
        mux.unregister({10, 20})

    @pytest.mark.asyncio
    async def test_deduplicates_when_multiple_read_ids_map_to_same_queue(self):
        mock_stream = AsyncMock()
        resp = _make_multi_range_response([10, 11])
        mock_stream.recv = AsyncMock(side_effect=[resp, None])

        mux = _StreamMultiplexer(mock_stream)
        queue = mux.register({10, 11})
        mux._ensure_recv_loop()

        item = await asyncio.wait_for(queue.get(), timeout=1)
        assert item is resp
        end = await asyncio.wait_for(queue.get(), timeout=1)
        assert isinstance(end, _StreamEnd)
        mux.unregister({10, 11})

    @pytest.mark.asyncio
    async def test_metadata_only_response_broadcast_to_all(self):
        mock_stream = AsyncMock()
        metadata_resp = _storage_v2.BidiReadObjectResponse(
            read_handle=_storage_v2.BidiReadHandle(handle=b"handle")
        )
        mock_stream.recv = AsyncMock(side_effect=[metadata_resp, None])

        mux = _StreamMultiplexer(mock_stream)
        q1 = mux.register({10})
        q2 = mux.register({20})
        mux._ensure_recv_loop()

        item1 = await asyncio.wait_for(q1.get(), timeout=1)
        item2 = await asyncio.wait_for(q2.get(), timeout=1)
        assert item1 is metadata_resp
        assert item2 is metadata_resp
        mux.unregister({10, 20})

    @pytest.mark.asyncio
    async def test_stream_end_sends_sentinel_to_all_queues(self):
        mock_stream = AsyncMock()
        mock_stream.recv = AsyncMock(return_value=None)

        mux = _StreamMultiplexer(mock_stream)
        q1 = mux.register({10})
        q2 = mux.register({20})
        mux._ensure_recv_loop()

        end1 = await asyncio.wait_for(q1.get(), timeout=1)
        end2 = await asyncio.wait_for(q2.get(), timeout=1)
        assert isinstance(end1, _StreamEnd)
        assert isinstance(end2, _StreamEnd)
        mux.unregister({10, 20})

    @pytest.mark.asyncio
    async def test_error_broadcasts_stream_error_to_all_queues(self):
        mock_stream = AsyncMock()
        exc = RuntimeError("stream broke")
        mock_stream.recv = AsyncMock(side_effect=exc)

        mux = _StreamMultiplexer(mock_stream)
        q1 = mux.register({10})
        q2 = mux.register({20})
        mux._ensure_recv_loop()

        # The recv loop returns once it has broadcast the error, so waiting
        # for the task is deterministic, unlike a fixed-length sleep.
        await asyncio.wait_for(mux._recv_task, timeout=1)

        err1 = q1.get_nowait()
        err2 = q2.get_nowait()
        assert isinstance(err1, _StreamError)
        assert err1.exception is exc
        assert err1.generation == 0
        assert isinstance(err2, _StreamError)
        assert err2.exception is exc
        mux.unregister({10, 20})

    @pytest.mark.asyncio
    async def test_error_uses_put_nowait(self):
        mock_stream = AsyncMock()
        exc = RuntimeError("broke")
        mock_stream.recv = AsyncMock(side_effect=exc)

        mux = _StreamMultiplexer(mock_stream, queue_max_size=1)
        queue = mux.register({10})
        queue.put_nowait("filler")
        mux._ensure_recv_loop()

        await asyncio.wait_for(mux._recv_task, timeout=1)

        # Queue is full (maxsize=1), but _put_error_nowait pops existing items
        # to ensure the error gets recorded.
        assert queue.qsize() == 1
        err = queue.get_nowait()
        assert isinstance(err, _StreamError)
        assert err.exception is exc
        mux.unregister({10})

    @pytest.mark.asyncio
    async def test_unknown_read_id_is_dropped(self):
        mock_stream = AsyncMock()
        resp = _make_response(read_id=999)
        mock_stream.recv = AsyncMock(side_effect=[resp, None])

        mux = _StreamMultiplexer(mock_stream)
        queue = mux.register({10})
        mux._ensure_recv_loop()

        end = await asyncio.wait_for(queue.get(), timeout=1)
        assert isinstance(end, _StreamEnd)
        mux.unregister({10})


class TestSend:
    @pytest.mark.asyncio
    async def test_send_forwards_to_stream(self):
        mock_stream = _make_blocking_stream()
        mux = _StreamMultiplexer(mock_stream)

        request = _storage_v2.BidiReadObjectRequest(
            read_ranges=[
                _storage_v2.ReadRange(read_id=1, read_offset=0, read_length=10)
            ]
        )
        gen = await mux.send(request)
        mock_stream.send.assert_called_once_with(request)
        assert gen == 0
        await mux.close()

    @pytest.mark.asyncio
    async def test_send_returns_current_generation(self):
        mock_stream = _make_blocking_stream()
        mux = _StreamMultiplexer(mock_stream)
        mux._stream_generation = 5

        request = _storage_v2.BidiReadObjectRequest()
        gen = await mux.send(request)
        assert gen == 5
        await mux.close()

    @pytest.mark.asyncio
    async def test_send_propagates_exception(self):
        mock_stream = _make_blocking_stream()
        mock_stream.send = AsyncMock(side_effect=RuntimeError("send failed"))
        mux = _StreamMultiplexer(mock_stream)

        with pytest.raises(RuntimeError, match="send failed"):
            await mux.send(_storage_v2.BidiReadObjectRequest())
        # send() starts the recv loop before the stream send fails.
        await mux.close()


class TestReopenStream:
    @pytest.mark.asyncio
    async def test_reopen_bumps_generation_and_replaces_stream(self):
        old_stream = _make_blocking_stream()
        mux = _StreamMultiplexer(old_stream)
        mux.register({1})
        assert mux.stream_generation == 0

        new_stream = _make_blocking_stream()
        factory = AsyncMock(return_value=new_stream)

        await mux.reopen_stream(0, factory)

        assert mux.stream_generation == 1
        assert mux._stream is new_stream
        factory.assert_called_once()
        mux.unregister({1})
        await mux.close()

    @pytest.mark.asyncio
    async def test_reopen_skips_if_generation_mismatch(self):
        mock_stream = _make_blocking_stream()
        mux = _StreamMultiplexer(mock_stream)
        mux._stream_generation = 5
        mux.register({1})

        factory = AsyncMock()
        await mux.reopen_stream(3, factory)

        assert mux.stream_generation == 5
        factory.assert_not_called()
        mux.unregister({1})

    @pytest.mark.asyncio
    async def test_reopen_broadcasts_error_before_bump(self):
        old_stream = _make_blocking_stream()
        mux = _StreamMultiplexer(old_stream)
        queue = mux.register({1})

        new_stream = _make_blocking_stream()
        factory = AsyncMock(return_value=new_stream)

        await mux.reopen_stream(0, factory)

        err = queue.get_nowait()
        assert isinstance(err, _StreamError)
        assert err.generation == 0
        mux.unregister({1})
        await mux.close()

    @pytest.mark.asyncio
    async def test_reopen_starts_new_recv_loop(self):
        old_stream = _make_blocking_stream()
        mux = _StreamMultiplexer(old_stream)
        mux.register({1})
        # Start a loop on the old stream so there is a real task to compare
        # against; register() alone leaves _recv_task as None.
        mux._ensure_recv_loop()
        old_recv_task = mux._recv_task
        assert old_recv_task is not None

        new_stream = _make_blocking_stream()
        factory = AsyncMock(return_value=new_stream)

        await mux.reopen_stream(0, factory)

        assert old_recv_task.done()
        assert mux._recv_task is not old_recv_task
        assert not mux._recv_task.done()
        mux.unregister({1})
        await mux.close()

    @pytest.mark.asyncio
    async def test_reopen_closes_old_stream_best_effort(self):
        old_stream = _make_blocking_stream()
        old_stream.close = AsyncMock(side_effect=RuntimeError("close failed"))
        mux = _StreamMultiplexer(old_stream)
        mux.register({1})

        new_stream = _make_blocking_stream()
        factory = AsyncMock(return_value=new_stream)

        await mux.reopen_stream(0, factory)
        assert mux.stream_generation == 1
        mux.unregister({1})
        await mux.close()

    @pytest.mark.asyncio
    async def test_concurrent_reopen_only_one_wins(self):
        old_stream = _make_blocking_stream()
        mux = _StreamMultiplexer(old_stream)
        mux.register({1})

        call_count = 0

        async def counting_factory():
            nonlocal call_count
            call_count += 1
            return _make_blocking_stream()

        await asyncio.gather(
            mux.reopen_stream(0, counting_factory),
            mux.reopen_stream(0, counting_factory),
        )

        assert call_count == 1
        assert mux.stream_generation == 1
        mux.unregister({1})
        await mux.close()

    @pytest.mark.asyncio
    async def test_reopen_factory_failure_leaves_generation_unchanged(self):
        """If stream_factory raises, generation is not bumped and recv loop
        is not restarted. The caller's retry manager will re-attempt reopen
        with the same generation, which will succeed because the generation
        check still matches."""
        old_stream = _make_blocking_stream()
        mux = _StreamMultiplexer(old_stream)
        mux.register({1})

        failing_factory = AsyncMock(side_effect=RuntimeError("open failed"))

        with pytest.raises(RuntimeError, match="open failed"):
            await mux.reopen_stream(0, failing_factory)

        # Generation was NOT bumped
        assert mux.stream_generation == 0
        # Recv loop was stopped and not restarted
        assert mux._recv_task is None or mux._recv_task.done()

        # A subsequent reopen with the same generation succeeds
        new_stream = _make_blocking_stream()
        ok_factory = AsyncMock(return_value=new_stream)

        await mux.reopen_stream(0, ok_factory)

        assert mux.stream_generation == 1
        assert mux._stream is new_stream
        assert mux._recv_task is not None and not mux._recv_task.done()
        mux.unregister({1})
        await mux.close()


class TestClose:
    @pytest.mark.asyncio
    async def test_close_cancels_recv_loop(self):
        mock_stream = _make_blocking_stream()
        mux = _StreamMultiplexer(mock_stream)
        mux.register({1})
        mux._ensure_recv_loop()
        recv_task = mux._recv_task

        await mux.close()
        assert recv_task.cancelled()

    @pytest.mark.asyncio
    async def test_close_broadcasts_terminal_error(self):
        mock_stream = _make_blocking_stream()
        mux = _StreamMultiplexer(mock_stream)
        q1 = mux.register({1})
        q2 = mux.register({2})

        await mux.close()

        err1 = q1.get_nowait()
        err2 = q2.get_nowait()
        assert isinstance(err1, _StreamError)
        assert isinstance(err2, _StreamError)

    @pytest.mark.asyncio
    async def test_close_with_no_tasks_is_noop(self):
        mock_stream = AsyncMock()
        mux = _StreamMultiplexer(mock_stream)
        await mux.close()  # should not raise
raise From bac705ad3716351d2d85e24e792e7595d8f6140f Mon Sep 17 00:00:00 2001 From: zhixiangli Date: Thu, 2 Apr 2026 09:45:30 +0000 Subject: [PATCH 02/18] feat: integrate _StreamMultiplexer into AsyncMultiRangeDownloader --- .../asyncio/async_multi_range_downloader.py | 222 ++++++++-------- .../test_async_multi_range_downloader.py | 238 +++++++++--------- 2 files changed, 238 insertions(+), 222 deletions(-) diff --git a/packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py b/packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py index cea21cb9ae66..9c0fdf05098a 100644 --- a/packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py +++ b/packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py @@ -25,6 +25,11 @@ from google.cloud import _storage_v2 from google.cloud.storage._helpers import generate_random_56_bit_integer +from google.cloud.storage.asyncio._stream_multiplexer import ( + _StreamEnd, + _StreamError, + _StreamMultiplexer, +) from google.cloud.storage.asyncio.async_grpc_client import ( AsyncGrpcClient, ) @@ -224,9 +229,7 @@ def __init__( self.read_obj_str: Optional[_AsyncReadObjectStream] = None self._is_stream_open: bool = False self._routing_token: Optional[str] = None - self._read_id_to_writable_buffer_dict = {} - self._read_id_to_download_ranges_id = {} - self._download_ranges_id_to_pending_read_ids = {} + self._multiplexer: Optional[_StreamMultiplexer] = None self.persisted_size: Optional[int] = None # updated after opening the stream self._open_retries: int = 0 @@ -328,6 +331,45 @@ async def _do_open(): self._is_stream_open = True await retry_policy(_do_open)() + self._multiplexer = _StreamMultiplexer(self.read_obj_str) + + def _create_stream_factory(self, state, metadata): + """Create a factory that opens a new stream with current routing state.""" + + async def factory(): + current_handle = 
state.get("read_handle") + current_token = state.get("routing_token") + + stream = _AsyncReadObjectStream( + client=self.client.grpc_client, + bucket_name=self.bucket_name, + object_name=self.object_name, + generation_number=self.generation, + read_handle=current_handle, + ) + + current_metadata = list(metadata) if metadata else [] + if current_token: + current_metadata.append( + ( + "x-goog-request-params", + f"routing_token={current_token}", + ) + ) + + await stream.open(metadata=current_metadata if current_metadata else None) + + if stream.generation_number: + self.generation = stream.generation_number + if stream.read_handle: + self.read_handle = stream.read_handle + + self.read_obj_str = stream + self._is_stream_open = True + + return stream + + return factory async def download_ranges( self, @@ -353,32 +395,8 @@ async def download_ranges( * (0, 0, buffer) : downloads 0 to end , i.e. entire object. * (100, 0, buffer) : downloads from 100 to end. - :type lock: asyncio.Lock - :param lock: (Optional) An asyncio lock to synchronize sends and recvs - on the underlying bidi-GRPC stream. This is required when multiple - coroutines are calling this method concurrently. - - i.e. Example usage with multiple coroutines: - - ``` - lock = asyncio.Lock() - task1 = asyncio.create_task(mrd.download_ranges(ranges1, lock)) - task2 = asyncio.create_task(mrd.download_ranges(ranges2, lock)) - await asyncio.gather(task1, task2) - - ``` - - If user want to call this method serially from multiple coroutines, - then providing a lock is not necessary. - - ``` - await mrd.download_ranges(ranges1) - await mrd.download_ranges(ranges2) - - # ... some other code code... - - ``` + :param lock: (Deprecated) This parameter is deprecated and has no effect. :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry` :param retry_policy: (Optional) The retry policy to use for the operation. 
@@ -397,9 +415,6 @@ async def download_ranges( if not self._is_stream_open: raise ValueError("Underlying bidi-gRPC stream is not open") - if lock is None: - lock = asyncio.Lock() - if retry_policy is None: retry_policy = AsyncRetry(predicate=_is_read_retryable) @@ -419,99 +434,97 @@ async def download_ranges( "routing_token": None, } - # Track attempts to manage stream reuse - attempt_count = 0 - - def send_ranges_and_get_bytes( - requests: List[_storage_v2.ReadRange], - state: Dict[str, Any], - metadata: Optional[List[Tuple[str, str]]] = None, - ): - async def generator(): - nonlocal attempt_count - attempt_count += 1 - - if attempt_count > 1: - logger.info( - f"Resuming download (attempt {attempt_count}) for {len(requests)} ranges." - ) + read_ids = set(download_states.keys()) + queue = self._multiplexer.register(read_ids) - async with lock: - current_handle = state.get("read_handle") - current_token = state.get("routing_token") + try: + attempt_count = 0 + last_broken_generation = None - # We reopen if it's a redirect (token exists) OR if this is a retry - # (not first attempt). This prevents trying to send data on a dead - # stream from a previous failed attempt. - should_reopen = ( - (attempt_count > 1) - or (current_token is not None) - or (metadata is not None) - ) + def send_and_recv_via_multiplexer( + requests: List[_storage_v2.ReadRange], + state: Dict[str, Any], + ): + async def generator(): + nonlocal attempt_count, last_broken_generation + attempt_count += 1 - if should_reopen: - if current_token: - logger.info( - f"Re-opening stream with routing token: {current_token}" - ) - - self.read_obj_str = _AsyncReadObjectStream( - client=self.client.grpc_client, - bucket_name=self.bucket_name, - object_name=self.object_name, - generation_number=self.generation, - read_handle=current_handle, + if attempt_count > 1: + logger.info( + f"Resuming download (attempt {attempt_count}) for {len(requests)} ranges." 
) - # Inject routing_token into metadata if present - current_metadata = list(metadata) if metadata else [] - if current_token: - current_metadata.append( - ( - "x-goog-request-params", - f"routing_token={current_token}", - ) - ) - - await self.read_obj_str.open( - metadata=current_metadata if current_metadata else None + # Reopen stream if needed + should_reopen = ( + attempt_count > 1 and last_broken_generation is not None + ) or (attempt_count == 1 and metadata is not None) + if should_reopen: + broken_gen = ( + last_broken_generation + if attempt_count > 1 + else self._multiplexer.stream_generation + ) + stream_factory = self._create_stream_factory(state, metadata) + await self._multiplexer.reopen_stream( + broken_gen, stream_factory ) - self._is_stream_open = True - pending_read_ids = {r.read_id for r in requests} + my_generation = self._multiplexer.stream_generation # Send Requests + pending_read_ids = {r.read_id for r in requests} for i in range( 0, len(requests), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST ): batch = requests[i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST] - await self.read_obj_str.send( - _storage_v2.BidiReadObjectRequest(read_ranges=batch) - ) + try: + await self._multiplexer.send( + _storage_v2.BidiReadObjectRequest(read_ranges=batch) + ) + except Exception: + last_broken_generation = my_generation + raise + # Receive Responses while pending_read_ids: - response = await self.read_obj_str.recv() - if response is None: + item = await queue.get() + + if isinstance(item, _StreamEnd): + if pending_read_ids: + last_broken_generation = my_generation + raise exceptions.ServiceUnavailable( + "Stream ended with pending read_ids" + ) break - if response.object_data_ranges: - for data_range in response.object_data_ranges: + + if isinstance(item, _StreamError): + if item.generation < my_generation: + continue # stale error, skip + last_broken_generation = item.generation + raise item.exception + + # Track completion + if item.object_data_ranges: + for 
data_range in item.object_data_ranges: if data_range.range_end: pending_read_ids.discard( data_range.read_range.read_id ) - yield response + yield item - return generator() + return generator() - strategy = _ReadResumptionStrategy() - retry_manager = _BidiStreamRetryManager( - strategy, lambda r, s: send_ranges_and_get_bytes(r, s, metadata=metadata) - ) + strategy = _ReadResumptionStrategy() + retry_manager = _BidiStreamRetryManager( + strategy, send_and_recv_via_multiplexer + ) - await retry_manager.execute(initial_state, retry_policy) + await retry_manager.execute(initial_state, retry_policy) - if initial_state.get("read_handle"): - self.read_handle = initial_state["read_handle"] + if initial_state.get("read_handle"): + self.read_handle = initial_state["read_handle"] + finally: + self._multiplexer.unregister(read_ids) async def close(self): """ @@ -520,8 +533,15 @@ async def close(self): if not self._is_stream_open: raise ValueError("Underlying bidi-gRPC stream is not open") + if self._multiplexer: + await self._multiplexer.close() + self._multiplexer = None + if self.read_obj_str: - await self.read_obj_str.close() + try: + await self.read_obj_str.close() + except (asyncio.CancelledError, exceptions.GoogleAPICallError): + pass self.read_obj_str = None self._is_stream_open = False diff --git a/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py b/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py index 80df5a438173..9b9f63f32e4f 100644 --- a/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py +++ b/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py @@ -60,6 +60,9 @@ async def _make_mock_mrd( mock_stream.generation_number = _TEST_GENERATION_NUMBER mock_stream.persisted_size = _TEST_OBJECT_SIZE mock_stream.read_handle = _TEST_READ_HANDLE + mock_stream.is_stream_open = True + # Default recv blocks forever (tests override with 
specific side_effect) + mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) mrd = await AsyncMultiRangeDownloader.create_mrd( mock_client, bucket_name, object_name, generation, read_handle @@ -102,69 +105,39 @@ async def test_create_mrd(self, mock_cls_async_read_object_stream): "google.cloud.storage.asyncio.async_multi_range_downloader._AsyncReadObjectStream" ) @pytest.mark.asyncio - async def test_download_ranges_via_async_gather( + async def test_download_ranges( self, mock_cls_async_read_object_stream, mock_random_int ): - # Arrange data = b"these_are_18_chars" crc32c = Checksum(data).digest() crc32c_int = int.from_bytes(crc32c, "big") - crc32c_checksum_for_data_slice = int.from_bytes( - Checksum(data[10:16]).digest(), "big" - ) mock_mrd, _ = await self._make_mock_mrd(mock_cls_async_read_object_stream) - - mock_random_int.side_effect = [456, 91011] + mock_random_int.side_effect = [456] mock_mrd.read_obj_str.send = AsyncMock() - mock_mrd.read_obj_str.recv = AsyncMock() - - mock_mrd.read_obj_str.recv.side_effect = [ - _storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - _storage_v2.ObjectRangeData( - checksummed_data=_storage_v2.ChecksummedData( - content=data, crc32c=crc32c_int - ), - range_end=True, - read_range=_storage_v2.ReadRange( - read_offset=0, read_length=18, read_id=456 - ), - ) - ] - ), - _storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - _storage_v2.ObjectRangeData( - checksummed_data=_storage_v2.ChecksummedData( - content=data[10:16], - crc32c=crc32c_checksum_for_data_slice, - ), - range_end=True, - read_range=_storage_v2.ReadRange( - read_offset=10, read_length=6, read_id=91011 - ), - ) - ], - ), - None, - ] - - # Act - buffer = BytesIO() - second_buffer = BytesIO() - lock = asyncio.Lock() - - task1 = asyncio.create_task(mock_mrd.download_ranges([(0, 18, buffer)], lock)) - task2 = asyncio.create_task( - mock_mrd.download_ranges([(10, 6, second_buffer)], lock) + mock_mrd.read_obj_str.recv = AsyncMock( + side_effect=[ + 
_storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + _storage_v2.ObjectRangeData( + checksummed_data=_storage_v2.ChecksummedData( + content=data, crc32c=crc32c_int + ), + range_end=True, + read_range=_storage_v2.ReadRange( + read_offset=0, read_length=18, read_id=456 + ), + ) + ], + ), + None, + ] ) - await asyncio.gather(task1, task2) - # Assert + buffer = BytesIO() + await mock_mrd.download_ranges([(0, 18, buffer)]) assert buffer.getvalue() == data - assert second_buffer.getvalue() == data[10:16] @mock.patch( "google.cloud.storage.asyncio.async_multi_range_downloader.generate_random_56_bit_integer" @@ -173,50 +146,78 @@ async def test_download_ranges_via_async_gather( "google.cloud.storage.asyncio.async_multi_range_downloader._AsyncReadObjectStream" ) @pytest.mark.asyncio - async def test_download_ranges( + async def test_download_ranges_via_async_gather( self, mock_cls_async_read_object_stream, mock_random_int ): - # Arrange data = b"these_are_18_chars" crc32c = Checksum(data).digest() crc32c_int = int.from_bytes(crc32c, "big") + crc32c_checksum_for_data_slice = int.from_bytes( + Checksum(data[10:16]).digest(), "big" + ) mock_mrd, _ = await self._make_mock_mrd(mock_cls_async_read_object_stream) + mock_random_int.side_effect = [456, 91011] - mock_random_int.side_effect = [456] + send_count = 0 + both_sent = asyncio.Event() + + async def counting_send(request): + nonlocal send_count + send_count += 1 + if send_count >= 2: + both_sent.set() + + mock_mrd.read_obj_str.send = AsyncMock(side_effect=counting_send) + + recv_call_count = 0 + + async def controlled_recv(): + nonlocal recv_call_count + recv_call_count += 1 + if recv_call_count == 1: + await both_sent.wait() + return _storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + _storage_v2.ObjectRangeData( + checksummed_data=_storage_v2.ChecksummedData( + content=data, crc32c=crc32c_int + ), + range_end=True, + read_range=_storage_v2.ReadRange( + read_offset=0, read_length=18, read_id=456 + ), + ) + ] 
+ ) + elif recv_call_count == 2: + return _storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + _storage_v2.ObjectRangeData( + checksummed_data=_storage_v2.ChecksummedData( + content=data[10:16], + crc32c=crc32c_checksum_for_data_slice, + ), + range_end=True, + read_range=_storage_v2.ReadRange( + read_offset=10, read_length=6, read_id=91011 + ), + ) + ], + ) + return None - mock_mrd.read_obj_str.send = AsyncMock() - mock_mrd.read_obj_str.recv = AsyncMock() - mock_mrd.read_obj_str.recv.side_effect = [ - _storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - _storage_v2.ObjectRangeData( - checksummed_data=_storage_v2.ChecksummedData( - content=data, crc32c=crc32c_int - ), - range_end=True, - read_range=_storage_v2.ReadRange( - read_offset=0, read_length=18, read_id=456 - ), - ) - ], - ), - None, - ] + mock_mrd.read_obj_str.recv = AsyncMock(side_effect=controlled_recv) - # Act buffer = BytesIO() - await mock_mrd.download_ranges([(0, 18, buffer)]) + second_buffer = BytesIO() + + task1 = asyncio.create_task(mock_mrd.download_ranges([(0, 18, buffer)])) + task2 = asyncio.create_task(mock_mrd.download_ranges([(10, 6, second_buffer)])) + await asyncio.gather(task1, task2) - # Assert - mock_mrd.read_obj_str.send.assert_called_once_with( - _storage_v2.BidiReadObjectRequest( - read_ranges=[ - _storage_v2.ReadRange(read_offset=0, read_length=18, read_id=456) - ] - ) - ) assert buffer.getvalue() == data + assert second_buffer.getvalue() == data[10:16] @pytest.mark.asyncio async def test_downloading_ranges_with_more_than_1000_should_throw_error(self): @@ -320,6 +321,7 @@ def test_init_raises_if_crc32c_c_extension_is_missing(self, mock_google_crc32c): async def test_download_ranges_raises_on_checksum_mismatch( self, mock_checksum_class ): + from google.cloud.storage.asyncio._stream_multiplexer import _StreamMultiplexer from google.cloud.storage.asyncio.async_multi_range_downloader import ( AsyncMultiRangeDownloader, ) @@ -353,6 +355,7 @@ async def 
test_download_ranges_raises_on_checksum_mismatch( mrd = AsyncMultiRangeDownloader(mock_client, "bucket", "object") mrd.read_obj_str = mock_stream mrd._is_stream_open = True + mrd._multiplexer = _StreamMultiplexer(mock_stream) with pytest.raises(DataCorruption) as exc_info: with mock.patch( @@ -419,6 +422,8 @@ async def test_create_mrd_with_generation_number( mock_stream.generation_number = _TEST_GENERATION_NUMBER mock_stream.persisted_size = _TEST_OBJECT_SIZE mock_stream.read_handle = _TEST_READ_HANDLE + mock_stream.is_stream_open = True + mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) # Act mrd = await AsyncMultiRangeDownloader.create_mrd( @@ -521,59 +526,50 @@ async def test_on_open_error_logs_warning(self, mock_logger): async def test_download_ranges_resumption_logging( self, mock_cls_async_read_object_stream, mock_random_int, mock_logger ): - # Arrange mock_mrd, _ = await self._make_mock_mrd(mock_cls_async_read_object_stream) - mock_mrd.read_obj_str.send = AsyncMock() - mock_mrd.read_obj_str.recv = AsyncMock() - from google.api_core import exceptions as core_exceptions retryable_exc = core_exceptions.ServiceUnavailable("Retry me") - # mock send to raise exception ONCE then succeed - mock_mrd.read_obj_str.send.side_effect = [ - retryable_exc, - None, # Success on second try - ] - - # mock recv for second try - mock_mrd.read_obj_str.recv.side_effect = [ - _storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - _storage_v2.ObjectRangeData( - checksummed_data=_storage_v2.ChecksummedData( - content=b"data", crc32c=123 - ), - range_end=True, - read_range=_storage_v2.ReadRange( - read_offset=0, read_length=4, read_id=123 - ), - ) - ] - ), - None, - ] + mock_mrd.read_obj_str.send = AsyncMock( + side_effect=[ + retryable_exc, + None, + ] + ) + + recv_call_count = 0 + + async def staged_recv(): + nonlocal recv_call_count + recv_call_count += 1 + if recv_call_count == 1: + return _storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + 
_storage_v2.ObjectRangeData( + checksummed_data=_storage_v2.ChecksummedData( + content=b"data", crc32c=123 + ), + range_end=True, + read_range=_storage_v2.ReadRange( + read_offset=0, read_length=4, read_id=123 + ), + ) + ] + ) + return None + + mock_mrd.read_obj_str.recv = AsyncMock(side_effect=staged_recv) + mock_mrd.read_obj_str.is_stream_open = True mock_random_int.return_value = 123 - # Act buffer = BytesIO() - # Patch Checksum where it is likely used (reads_resumption_strategy or similar), - # but actually if we use google_crc32c directly, we should patch that or provide valid CRC. - # Since we can't reliably predict where Checksum is imported/used without more digging, - # let's provide a valid CRC for b"data". - # Checksum(b"data").digest() -> needs to match crc32c=123. - # But we can't force b"data" to have crc=123. - # So we MUST patch Checksum. - # It is used in google.cloud.storage.asyncio.retry.reads_resumption_strategy - with mock.patch( "google.cloud.storage.asyncio.retry.reads_resumption_strategy.Checksum" ) as mock_chk: mock_chk.return_value.digest.return_value = (123).to_bytes(4, "big") - await mock_mrd.download_ranges([(0, 4, buffer)]) - # Assert mock_logger.info.assert_any_call("Resuming download (attempt 2) for 1 ranges.") From 2c442918fa3ea453114e3213d595e2cdc9d10c9a Mon Sep 17 00:00:00 2001 From: zhixiangli Date: Thu, 2 Apr 2026 09:45:33 +0000 Subject: [PATCH 03/18] test: add system tests for concurrent AsyncMultiRangeDownloader --- .../tests/system/test_zonal.py | 194 +++++++++++++++++- 1 file changed, 193 insertions(+), 1 deletion(-) diff --git a/packages/google-cloud-storage/tests/system/test_zonal.py b/packages/google-cloud-storage/tests/system/test_zonal.py index edd323b037ec..f30d627ae033 100644 --- a/packages/google-cloud-storage/tests/system/test_zonal.py +++ b/packages/google-cloud-storage/tests/system/test_zonal.py @@ -2,13 +2,14 @@ import asyncio import gc import os +import random import uuid from io import BytesIO # python 
additional imports import google_crc32c import pytest -from google.api_core.exceptions import FailedPrecondition, NotFound +from google.api_core.exceptions import FailedPrecondition, NotFound, OutOfRange from google.cloud.storage.asyncio.async_appendable_object_writer import ( _DEFAULT_FLUSH_INTERVAL_BYTES, @@ -594,3 +595,194 @@ async def _run(): gc.collect() event_loop.run_until_complete(_run()) + + +def test_mrd_concurrent_download( + storage_client, blobs_to_delete, event_loop, grpc_client +): + """ + Test that mrd can handle concurrent `download_ranges` calls correctly. + Tests overlapping ranges, high concurrency (len > 100 multiplexing batch limits), + mixed random chunk sizes (small/medium/large), and full object fetching alongside specific chunks. + """ + object_size = 15 * 1024 * 1024 # 15MB + object_name = f"test_mrd_concurrent-{uuid.uuid4()}" + + async def _run(): + object_data = os.urandom(object_size) + + writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) + await writer.open() + await writer.append(object_data) + await writer.close(finalize_on_close=True) + + async with AsyncMultiRangeDownloader( + grpc_client, _ZONAL_BUCKET, object_name + ) as mrd: + tasks = [] + ranges_to_fetch = [] + + # Overlapping ranges & Mixed random chunk sizes + # Small chunks + for _ in range(60): + start = random.randint(0, object_size - 100) + length = random.randint(1, 100) + ranges_to_fetch.append((start, length)) + # Medium chunks + for _ in range(60): + start = random.randint(0, object_size - 100000) + length = random.randint(100, 100000) + ranges_to_fetch.append((start, length)) + # Large chunks + for _ in range(5): + start = random.randint(0, object_size - 2000000) + length = random.randint(1000000, 2000000) + ranges_to_fetch.append((start, length)) + + # Full object fetching concurrently + ranges_to_fetch.append((0, 0)) + + # High concurrency batching (Total > 100 ranges) + assert len(ranges_to_fetch) > 100 + random.shuffle(ranges_to_fetch) 
+ + buffers = [BytesIO() for _ in range(len(ranges_to_fetch))] + + for idx, (start, length) in enumerate(ranges_to_fetch): + tasks.append( + asyncio.create_task( + mrd.download_ranges([(start, length, buffers[idx])]) + ) + ) + + await asyncio.gather(*tasks) + + # Validation + for idx, (start, length) in enumerate(ranges_to_fetch): + if length == 0: + expected_data = object_data[start:] + else: + expected_data = object_data[start : start + length] + assert buffers[idx].getvalue() == expected_data + + del writer + gc.collect() + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + + event_loop.run_until_complete(_run()) + + +def test_mrd_concurrent_download_cancellation( + storage_client, blobs_to_delete, event_loop, grpc_client +): + """ + Test task cancellation / abort mid-stream. + Tests that downloading gracefully manages memory and internal references + when tasks are canceled during active multiplexing, without breaking remaining downloads. + """ + object_size = 5 * 1024 * 1024 # 5MB + object_name = f"test_mrd_cancel-{uuid.uuid4()}" + + async def _run(): + object_data = os.urandom(object_size) + + writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) + await writer.open() + await writer.append(object_data) + await writer.close(finalize_on_close=True) + + async with AsyncMultiRangeDownloader( + grpc_client, _ZONAL_BUCKET, object_name + ) as mrd: + tasks = [] + num_chunks = 100 + chunk_size = object_size // num_chunks + buffers = [BytesIO() for _ in range(num_chunks)] + + for i in range(num_chunks): + start = i * chunk_size + tasks.append( + asyncio.create_task( + mrd.download_ranges([(start, chunk_size, buffers[i])]) + ) + ) + + # Let the loop start sending Bidi requests + await asyncio.sleep(0.01) + + # Cancel a subset of evenly distributed tasks + for i in range(0, num_chunks, 2): + tasks[i].cancel() + + results = await asyncio.gather(*tasks, return_exceptions=True) + + for i in range(num_chunks): + if i % 2 
== 0: + assert isinstance(results[i], asyncio.CancelledError) + else: + start = i * chunk_size + expected_data = object_data[start : start + chunk_size] + assert buffers[i].getvalue() == expected_data + + del writer + gc.collect() + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + + event_loop.run_until_complete(_run()) + + +def test_mrd_concurrent_download_out_of_bounds( + storage_client, blobs_to_delete, event_loop, grpc_client +): + """ + Test out-of-bounds & edge ranges concurrent with valid requests. + Verifies isolation: invalid bounds generate correct exceptions and don't stall the stream + for concurrently valid requests. + """ + object_size = 2 * 1024 * 1024 # 2MB + object_name = f"test_mrd_oob-{uuid.uuid4()}" + + async def _run(): + object_data = os.urandom(object_size) + + writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) + await writer.open() + await writer.append(object_data) + await writer.close(finalize_on_close=True) + + async with AsyncMultiRangeDownloader( + grpc_client, _ZONAL_BUCKET, object_name + ) as mrd: + b_valid = BytesIO() + t_valid = asyncio.create_task(mrd.download_ranges([(0, 100, b_valid)])) + + b_oob1 = BytesIO() + t_oob1 = asyncio.create_task( + mrd.download_ranges([(object_size + 1000, 100, b_oob1)]) + ) + + # EOF ask for 100 bytes + b_oob2 = BytesIO() + t_oob2 = asyncio.create_task( + mrd.download_ranges([(object_size, 100, b_oob2)]) + ) + + results = await asyncio.gather( + t_valid, t_oob1, t_oob2, return_exceptions=True + ) + + # Verify valid one processed correctly + assert b_valid.getvalue() == object_data[:100] + + # Verify fully OOB request returned Exception + assert isinstance(results[1], OutOfRange) + + # Verify request exactly at EOF successfully completed with 0 bytes + assert not isinstance(results[2], Exception) + assert b_oob2.getvalue() == b"" + + del writer + gc.collect() + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + + 
event_loop.run_until_complete(_run()) From d6c9bfee1c809950f55e3bfca9409ad157e75865 Mon Sep 17 00:00:00 2001 From: zhixiangli Date: Wed, 8 Apr 2026 01:33:04 +0000 Subject: [PATCH 04/18] test: use grpc_client_direct in AsyncMultiRangeDownloader system tests --- .../tests/system/test_zonal.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/packages/google-cloud-storage/tests/system/test_zonal.py b/packages/google-cloud-storage/tests/system/test_zonal.py index f30d627ae033..eb0712e12663 100644 --- a/packages/google-cloud-storage/tests/system/test_zonal.py +++ b/packages/google-cloud-storage/tests/system/test_zonal.py @@ -598,7 +598,7 @@ async def _run(): def test_mrd_concurrent_download( - storage_client, blobs_to_delete, event_loop, grpc_client + storage_client, blobs_to_delete, event_loop, grpc_client_direct ): """ Test that mrd can handle concurrent `download_ranges` calls correctly. @@ -611,13 +611,13 @@ def test_mrd_concurrent_download( async def _run(): object_data = os.urandom(object_size) - writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) + writer = AsyncAppendableObjectWriter(grpc_client_direct, _ZONAL_BUCKET, object_name) await writer.open() await writer.append(object_data) await writer.close(finalize_on_close=True) async with AsyncMultiRangeDownloader( - grpc_client, _ZONAL_BUCKET, object_name + grpc_client_direct, _ZONAL_BUCKET, object_name ) as mrd: tasks = [] ranges_to_fetch = [] @@ -673,7 +673,7 @@ async def _run(): def test_mrd_concurrent_download_cancellation( - storage_client, blobs_to_delete, event_loop, grpc_client + storage_client, blobs_to_delete, event_loop, grpc_client_direct ): """ Test task cancellation / abort mid-stream. 
@@ -686,13 +686,13 @@ def test_mrd_concurrent_download_cancellation( async def _run(): object_data = os.urandom(object_size) - writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) + writer = AsyncAppendableObjectWriter(grpc_client_direct, _ZONAL_BUCKET, object_name) await writer.open() await writer.append(object_data) await writer.close(finalize_on_close=True) async with AsyncMultiRangeDownloader( - grpc_client, _ZONAL_BUCKET, object_name + grpc_client_direct, _ZONAL_BUCKET, object_name ) as mrd: tasks = [] num_chunks = 100 @@ -732,7 +732,7 @@ async def _run(): def test_mrd_concurrent_download_out_of_bounds( - storage_client, blobs_to_delete, event_loop, grpc_client + storage_client, blobs_to_delete, event_loop, grpc_client_direct ): """ Test out-of-bounds & edge ranges concurrent with valid requests. @@ -745,13 +745,13 @@ def test_mrd_concurrent_download_out_of_bounds( async def _run(): object_data = os.urandom(object_size) - writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) + writer = AsyncAppendableObjectWriter(grpc_client_direct, _ZONAL_BUCKET, object_name) await writer.open() await writer.append(object_data) await writer.close(finalize_on_close=True) async with AsyncMultiRangeDownloader( - grpc_client, _ZONAL_BUCKET, object_name + grpc_client_direct, _ZONAL_BUCKET, object_name ) as mrd: b_valid = BytesIO() t_valid = asyncio.create_task(mrd.download_ranges([(0, 100, b_valid)])) From f2fb305f78197d66993a89e1fbafa90ac552cb8f Mon Sep 17 00:00:00 2001 From: zhixiangli Date: Wed, 8 Apr 2026 01:35:37 +0000 Subject: [PATCH 05/18] feat: rename _DEFAULT_PUT_TIMEOUT to _DEFAULT_PUT_TIMEOUT_SECONDS in _stream_multiplexer.py --- .../google/cloud/storage/asyncio/_stream_multiplexer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py index 
abd99e087cd3..6e35ffac0b44 100644 --- a/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py +++ b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) _DEFAULT_QUEUE_MAX_SIZE = 100 -_DEFAULT_PUT_TIMEOUT = 20.0 +_DEFAULT_PUT_TIMEOUT_SECONDS = 20.0 class _StreamError: @@ -87,7 +87,7 @@ def _get_unique_queues(self) -> Set[asyncio.Queue]: async def _put_with_timeout(self, queue: asyncio.Queue, item) -> None: try: - await asyncio.wait_for(queue.put(item), timeout=_DEFAULT_PUT_TIMEOUT) + await asyncio.wait_for(queue.put(item), timeout=_DEFAULT_PUT_TIMEOUT_SECONDS) except asyncio.TimeoutError: if queue not in self._get_unique_queues(): logger.debug("Dropped item for unregistered queue.") From f949b67e1aa90192f8505f34162632925f467039 Mon Sep 17 00:00:00 2001 From: zhixiangli Date: Wed, 8 Apr 2026 01:45:06 +0000 Subject: [PATCH 06/18] chore: update copyright year to 2026 --- .../google/cloud/storage/asyncio/_stream_multiplexer.py | 4 ++-- .../tests/unit/asyncio/test_stream_multiplexer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py index 6e35ffac0b44..7725ceceffde 100644 --- a/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py +++ b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py @@ -1,10 +1,10 @@ -# Copyright 2025 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# https://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py b/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py index 4bf5bfaf4e3b..2549ab5bf107 100644 --- a/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py +++ b/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py @@ -1,4 +1,4 @@ -# Copyright 2025 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 83a98f47ccb17c1ee5ec59724c99621aa8a3e213 Mon Sep 17 00:00:00 2001 From: zhixiangli Date: Wed, 8 Apr 2026 01:50:12 +0000 Subject: [PATCH 07/18] feat: rename my_generation to stream_generation in async_multi_range_downloader.py --- .../cloud/storage/asyncio/async_multi_range_downloader.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py b/packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py index 9c0fdf05098a..ac0844519e2d 100644 --- a/packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py +++ b/packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py @@ -469,7 +469,7 @@ async def generator(): broken_gen, stream_factory ) - my_generation = self._multiplexer.stream_generation + stream_generation = self._multiplexer.stream_generation # Send Requests pending_read_ids = {r.read_id for r in requests} @@ -482,7 +482,7 @@ async def generator(): _storage_v2.BidiReadObjectRequest(read_ranges=batch) ) except 
Exception: - last_broken_generation = my_generation + last_broken_generation = stream_generation raise # Receive Responses @@ -491,14 +491,14 @@ async def generator(): if isinstance(item, _StreamEnd): if pending_read_ids: - last_broken_generation = my_generation + last_broken_generation = stream_generation raise exceptions.ServiceUnavailable( "Stream ended with pending read_ids" ) break if isinstance(item, _StreamError): - if item.generation < my_generation: + if item.generation < stream_generation: continue # stale error, skip last_broken_generation = item.generation raise item.exception From 1092bacf9a67d7425dcff94c26de50424102bffe Mon Sep 17 00:00:00 2001 From: zhixiangli Date: Wed, 8 Apr 2026 02:03:54 +0000 Subject: [PATCH 08/18] test: parametrize test_mrd_concurrent_download for different chunk sizes --- .../tests/system/test_zonal.py | 48 ++++++++++--------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/packages/google-cloud-storage/tests/system/test_zonal.py b/packages/google-cloud-storage/tests/system/test_zonal.py index eb0712e12663..d5e8c1c2d60e 100644 --- a/packages/google-cloud-storage/tests/system/test_zonal.py +++ b/packages/google-cloud-storage/tests/system/test_zonal.py @@ -597,13 +597,27 @@ async def _run(): event_loop.run_until_complete(_run()) +@pytest.mark.parametrize( + "ranges_desc, chunk_ranges", + [ + ("small", [(1, 100)] * 3), + ("medium", [(100, 100000)] * 3), + ("large", [(1000000, 2000000)] * 3), + ("mixed", [(1, 100), (100, 100000), (1000000, 2000000)]), + ], +) def test_mrd_concurrent_download( - storage_client, blobs_to_delete, event_loop, grpc_client_direct + storage_client, + blobs_to_delete, + event_loop, + grpc_client_direct, + ranges_desc, + chunk_ranges, ): """ Test that mrd can handle concurrent `download_ranges` calls correctly. - Tests overlapping ranges, high concurrency (len > 100 multiplexing batch limits), - mixed random chunk sizes (small/medium/large), and full object fetching alongside specific chunks. 
+ Tests overlapping ranges, minimal concurrency, + parametrized chunk sizes (small/medium/large/mixed), and full object fetching alongside specific chunks. """ object_size = 15 * 1024 * 1024 # 15MB object_name = f"test_mrd_concurrent-{uuid.uuid4()}" @@ -611,7 +625,9 @@ def test_mrd_concurrent_download( async def _run(): object_data = os.urandom(object_size) - writer = AsyncAppendableObjectWriter(grpc_client_direct, _ZONAL_BUCKET, object_name) + writer = AsyncAppendableObjectWriter( + grpc_client_direct, _ZONAL_BUCKET, object_name + ) await writer.open() await writer.append(object_data) await writer.close(finalize_on_close=True) @@ -622,28 +638,14 @@ async def _run(): tasks = [] ranges_to_fetch = [] - # Overlapping ranges & Mixed random chunk sizes - # Small chunks - for _ in range(60): - start = random.randint(0, object_size - 100) - length = random.randint(1, 100) - ranges_to_fetch.append((start, length)) - # Medium chunks - for _ in range(60): - start = random.randint(0, object_size - 100000) - length = random.randint(100, 100000) - ranges_to_fetch.append((start, length)) - # Large chunks - for _ in range(5): - start = random.randint(0, object_size - 2000000) - length = random.randint(1000000, 2000000) + for min_len, max_len in chunk_ranges: + start = random.randint(0, object_size - max_len) + length = random.randint(min_len, max_len) ranges_to_fetch.append((start, length)) # Full object fetching concurrently ranges_to_fetch.append((0, 0)) - # High concurrency batching (Total > 100 ranges) - assert len(ranges_to_fetch) > 100 random.shuffle(ranges_to_fetch) buffers = [BytesIO() for _ in range(len(ranges_to_fetch))] @@ -667,7 +669,9 @@ async def _run(): del writer gc.collect() - blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + blobs_to_delete.append( + storage_client.bucket(_ZONAL_BUCKET).blob(object_name) + ) event_loop.run_until_complete(_run()) From 18487414a5ed489a4c720809c94cc1c05866a327 Mon Sep 17 00:00:00 2001 From: zhixiangli 
Date: Wed, 8 Apr 2026 02:12:16 +0000 Subject: [PATCH 09/18] chore: format tests/unit/asyncio/test_async_multi_range_downloader.py --- .../test_async_multi_range_downloader.py | 82 +++++++++---------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py b/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py index 9b9f63f32e4f..19847db015eb 100644 --- a/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py +++ b/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py @@ -98,47 +98,6 @@ async def test_create_mrd(self, mock_cls_async_read_object_stream): assert mrd.is_stream_open assert mrd._open_retries == 0 - @mock.patch( - "google.cloud.storage.asyncio.async_multi_range_downloader.generate_random_56_bit_integer" - ) - @mock.patch( - "google.cloud.storage.asyncio.async_multi_range_downloader._AsyncReadObjectStream" - ) - @pytest.mark.asyncio - async def test_download_ranges( - self, mock_cls_async_read_object_stream, mock_random_int - ): - data = b"these_are_18_chars" - crc32c = Checksum(data).digest() - crc32c_int = int.from_bytes(crc32c, "big") - - mock_mrd, _ = await self._make_mock_mrd(mock_cls_async_read_object_stream) - mock_random_int.side_effect = [456] - - mock_mrd.read_obj_str.send = AsyncMock() - mock_mrd.read_obj_str.recv = AsyncMock( - side_effect=[ - _storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - _storage_v2.ObjectRangeData( - checksummed_data=_storage_v2.ChecksummedData( - content=data, crc32c=crc32c_int - ), - range_end=True, - read_range=_storage_v2.ReadRange( - read_offset=0, read_length=18, read_id=456 - ), - ) - ], - ), - None, - ] - ) - - buffer = BytesIO() - await mock_mrd.download_ranges([(0, 18, buffer)]) - assert buffer.getvalue() == data - @mock.patch( 
"google.cloud.storage.asyncio.async_multi_range_downloader.generate_random_56_bit_integer" ) @@ -219,6 +178,47 @@ async def controlled_recv(): assert buffer.getvalue() == data assert second_buffer.getvalue() == data[10:16] + @mock.patch( + "google.cloud.storage.asyncio.async_multi_range_downloader.generate_random_56_bit_integer" + ) + @mock.patch( + "google.cloud.storage.asyncio.async_multi_range_downloader._AsyncReadObjectStream" + ) + @pytest.mark.asyncio + async def test_download_ranges( + self, mock_cls_async_read_object_stream, mock_random_int + ): + data = b"these_are_18_chars" + crc32c = Checksum(data).digest() + crc32c_int = int.from_bytes(crc32c, "big") + + mock_mrd, _ = await self._make_mock_mrd(mock_cls_async_read_object_stream) + mock_random_int.side_effect = [456] + + mock_mrd.read_obj_str.send = AsyncMock() + mock_mrd.read_obj_str.recv = AsyncMock( + side_effect=[ + _storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + _storage_v2.ObjectRangeData( + checksummed_data=_storage_v2.ChecksummedData( + content=data, crc32c=crc32c_int + ), + range_end=True, + read_range=_storage_v2.ReadRange( + read_offset=0, read_length=18, read_id=456 + ), + ) + ], + ), + None, + ] + ) + + buffer = BytesIO() + await mock_mrd.download_ranges([(0, 18, buffer)]) + assert buffer.getvalue() == data + @pytest.mark.asyncio async def test_downloading_ranges_with_more_than_1000_should_throw_error(self): # Arrange From 9d010308fea6ddf181867b47d6cb2f72719a3ebd Mon Sep 17 00:00:00 2001 From: zhixiangli Date: Wed, 8 Apr 2026 02:13:33 +0000 Subject: [PATCH 10/18] chore: format _stream_multiplexer.py and test_zonal.py --- .../cloud/storage/asyncio/_stream_multiplexer.py | 4 +++- .../google-cloud-storage/tests/system/test_zonal.py | 12 +++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py index 
7725ceceffde..caabb5b8f089 100644 --- a/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py +++ b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py @@ -87,7 +87,9 @@ def _get_unique_queues(self) -> Set[asyncio.Queue]: async def _put_with_timeout(self, queue: asyncio.Queue, item) -> None: try: - await asyncio.wait_for(queue.put(item), timeout=_DEFAULT_PUT_TIMEOUT_SECONDS) + await asyncio.wait_for( + queue.put(item), timeout=_DEFAULT_PUT_TIMEOUT_SECONDS + ) except asyncio.TimeoutError: if queue not in self._get_unique_queues(): logger.debug("Dropped item for unregistered queue.") diff --git a/packages/google-cloud-storage/tests/system/test_zonal.py b/packages/google-cloud-storage/tests/system/test_zonal.py index d5e8c1c2d60e..3c9e7583a46a 100644 --- a/packages/google-cloud-storage/tests/system/test_zonal.py +++ b/packages/google-cloud-storage/tests/system/test_zonal.py @@ -669,9 +669,7 @@ async def _run(): del writer gc.collect() - blobs_to_delete.append( - storage_client.bucket(_ZONAL_BUCKET).blob(object_name) - ) + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) event_loop.run_until_complete(_run()) @@ -690,7 +688,9 @@ def test_mrd_concurrent_download_cancellation( async def _run(): object_data = os.urandom(object_size) - writer = AsyncAppendableObjectWriter(grpc_client_direct, _ZONAL_BUCKET, object_name) + writer = AsyncAppendableObjectWriter( + grpc_client_direct, _ZONAL_BUCKET, object_name + ) await writer.open() await writer.append(object_data) await writer.close(finalize_on_close=True) @@ -749,7 +749,9 @@ def test_mrd_concurrent_download_out_of_bounds( async def _run(): object_data = os.urandom(object_size) - writer = AsyncAppendableObjectWriter(grpc_client_direct, _ZONAL_BUCKET, object_name) + writer = AsyncAppendableObjectWriter( + grpc_client_direct, _ZONAL_BUCKET, object_name + ) await writer.open() await writer.append(object_data) await 
writer.close(finalize_on_close=True) From 6fe889244ac3d3abc90a99c98cd33b5e29d98d36 Mon Sep 17 00:00:00 2001 From: zhixiangli Date: Wed, 8 Apr 2026 02:44:50 +0000 Subject: [PATCH 11/18] test: add Given/When/Then comments and cleanup test_stream_multiplexer.py --- .../unit/asyncio/test_stream_multiplexer.py | 166 ++++++++++++++---- 1 file changed, 130 insertions(+), 36 deletions(-) diff --git a/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py b/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py index 2549ab5bf107..e2a15b482148 100644 --- a/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py +++ b/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py @@ -28,20 +28,26 @@ class TestSentinelTypes: def test_stream_error_stores_exception_and_generation(self): + # Given an exception and a generation exc = ValueError("test") + + # When a StreamError is created error = _StreamError(exc, generation=3) + + # Then it stores the exception and generation assert error.exception is exc assert error.generation == 3 - def test_stream_end_is_instantiable(self): - sentinel = _StreamEnd() - assert isinstance(sentinel, _StreamEnd) - class TestStreamMultiplexerInit: def test_init_sets_stream_and_defaults(self): + # Given a mock stream mock_stream = AsyncMock() + + # When a multiplexer is created mux = _StreamMultiplexer(mock_stream) + + # Then it sets the stream and defaults assert mux._stream is mock_stream assert mux.stream_generation == 0 assert mux._queues == {} @@ -49,8 +55,13 @@ def test_init_sets_stream_and_defaults(self): assert mux._queue_max_size == _DEFAULT_QUEUE_MAX_SIZE def test_init_custom_queue_size(self): + # Given a mock stream mock_stream = AsyncMock() + + # When a multiplexer is created with a custom queue size mux = _StreamMultiplexer(mock_stream, queue_max_size=50) + + # Then it sets the custom queue size assert mux._queue_max_size == 50 @@ -90,70 +101,98 @@ def 
_make_multiplexer(self): @pytest.mark.asyncio async def test_register_returns_bounded_queue(self): + # Given a multiplexer mux, _ = self._make_multiplexer() + + # When registering read IDs queue = mux.register({1, 2, 3}) + + # Then a bounded queue is returned assert isinstance(queue, asyncio.Queue) assert queue.maxsize == _DEFAULT_QUEUE_MAX_SIZE - mux.unregister({1, 2, 3}) @pytest.mark.asyncio async def test_register_maps_read_ids_to_same_queue(self): + # Given a multiplexer mux, _ = self._make_multiplexer() + + # When registering multiple read IDs queue = mux.register({10, 20}) + + # Then they map to the same queue assert mux._queues[10] is queue assert mux._queues[20] is queue - mux.unregister({10, 20}) @pytest.mark.asyncio async def test_register_does_not_start_recv_loop(self): + # Given a multiplexer mux, _ = self._make_multiplexer() - assert mux._recv_task is None + + # When registering a read ID mux.register({1}) + + # Then the receive loop is not started assert mux._recv_task is None - mux.unregister({1}) @pytest.mark.asyncio async def test_two_registers_get_separate_queues(self): + # Given a multiplexer mux, _ = self._make_multiplexer() + + # When registering different read IDs separately q1 = mux.register({1}) q2 = mux.register({2}) + + # Then separate queues are returned assert q1 is not q2 assert mux._queues[1] is q1 assert mux._queues[2] is q2 - mux.unregister({1, 2}) @pytest.mark.asyncio async def test_unregister_removes_read_ids(self): + # Given a multiplexer with registered read IDs mux, _ = self._make_multiplexer() mux.register({1, 2}) + + # When unregistering a read ID mux.unregister({1}) + + # Then it is removed from the mapping assert 1 not in mux._queues assert 2 in mux._queues - mux.unregister({2}) @pytest.mark.asyncio async def test_unregister_all_does_not_stop_recv_loop(self): + # Given a multiplexer with an active receive loop mux, _ = self._make_multiplexer() mux.register({1}) mux._ensure_recv_loop() recv_task = mux._recv_task - assert 
recv_task is not None + + # When unregistering the read ID mux.unregister({1}) + + # Then the receive loop is not cancelled await asyncio.sleep(0) assert not recv_task.cancelled() @pytest.mark.asyncio async def test_unregister_nonexistent_is_noop(self): + # Given a multiplexer with a registered read ID mux, _ = self._make_multiplexer() mux.register({1}) + + # When unregistering a non-existent read ID mux.unregister({999}) + + # Then the existing registration remains assert 1 in mux._queues - mux.unregister({1}) class TestRecvLoop: @pytest.mark.asyncio async def test_routes_response_by_read_id(self): + # Given a multiplexer with registered queues for read IDs 10 and 20 mock_stream = AsyncMock() resp1 = _make_response(read_id=10, data=b"hello") resp2 = _make_response(read_id=20, data=b"world") @@ -162,37 +201,45 @@ async def test_routes_response_by_read_id(self): mux = _StreamMultiplexer(mock_stream) q1 = mux.register({10}) q2 = mux.register({20}) + + # When the receive loop is started mux._ensure_recv_loop() + # Then responses are routed to the corresponding queues and stream ends are sent item1 = await asyncio.wait_for(q1.get(), timeout=1) item2 = await asyncio.wait_for(q2.get(), timeout=1) assert item1 is resp1 assert item2 is resp2 + end1 = await asyncio.wait_for(q1.get(), timeout=1) end2 = await asyncio.wait_for(q2.get(), timeout=1) assert isinstance(end1, _StreamEnd) assert isinstance(end2, _StreamEnd) - mux.unregister({10, 20}) @pytest.mark.asyncio async def test_deduplicates_when_multiple_read_ids_map_to_same_queue(self): + # Given a multiplexer with multiple read IDs mapped to the same queue mock_stream = AsyncMock() resp = _make_multi_range_response([10, 11]) mock_stream.recv = AsyncMock(side_effect=[resp, None]) mux = _StreamMultiplexer(mock_stream) queue = mux.register({10, 11}) + + # When the receive loop is started mux._ensure_recv_loop() + # Then the response is put into the queue only once item = await asyncio.wait_for(queue.get(), timeout=1) assert 
item is resp + end = await asyncio.wait_for(queue.get(), timeout=1) assert isinstance(end, _StreamEnd) - mux.unregister({10, 11}) @pytest.mark.asyncio async def test_metadata_only_response_broadcast_to_all(self): + # Given a multiplexer with multiple registered queues mock_stream = AsyncMock() metadata_resp = _storage_v2.BidiReadObjectResponse( read_handle=_storage_v2.BidiReadHandle(handle=b"handle") @@ -202,32 +249,38 @@ async def test_metadata_only_response_broadcast_to_all(self): mux = _StreamMultiplexer(mock_stream) q1 = mux.register({10}) q2 = mux.register({20}) + + # When the receive loop is started mux._ensure_recv_loop() + # Then the metadata-only response is broadcast to all queues item1 = await asyncio.wait_for(q1.get(), timeout=1) item2 = await asyncio.wait_for(q2.get(), timeout=1) assert item1 is metadata_resp assert item2 is metadata_resp - mux.unregister({10, 20}) @pytest.mark.asyncio async def test_stream_end_sends_sentinel_to_all_queues(self): + # Given a multiplexer with multiple registered queues and a stream that ends immediately mock_stream = AsyncMock() mock_stream.recv = AsyncMock(return_value=None) mux = _StreamMultiplexer(mock_stream) q1 = mux.register({10}) q2 = mux.register({20}) + + # When the receive loop is started mux._ensure_recv_loop() + # Then a StreamEnd sentinel is sent to all queues end1 = await asyncio.wait_for(q1.get(), timeout=1) end2 = await asyncio.wait_for(q2.get(), timeout=1) assert isinstance(end1, _StreamEnd) assert isinstance(end2, _StreamEnd) - mux.unregister({10, 20}) @pytest.mark.asyncio async def test_error_broadcasts_stream_error_to_all_queues(self): + # Given a multiplexer with multiple registered queues and a stream that raises an error mock_stream = AsyncMock() exc = RuntimeError("stream broke") mock_stream.recv = AsyncMock(side_effect=exc) @@ -235,10 +288,12 @@ async def test_error_broadcasts_stream_error_to_all_queues(self): mux = _StreamMultiplexer(mock_stream) q1 = mux.register({10}) q2 = mux.register({20}) 
- mux._ensure_recv_loop() + # When the receive loop is started + mux._ensure_recv_loop() await asyncio.sleep(0.05) + # Then a StreamError is broadcast to all queues err1 = q1.get_nowait() err2 = q2.get_nowait() assert isinstance(err1, _StreamError) @@ -246,10 +301,10 @@ async def test_error_broadcasts_stream_error_to_all_queues(self): assert err1.generation == 0 assert isinstance(err2, _StreamError) assert err2.exception is exc - mux.unregister({10, 20}) @pytest.mark.asyncio async def test_error_uses_put_nowait(self): + # Given a multiplexer with a full queue and a stream that raises an error mock_stream = AsyncMock() exc = RuntimeError("broke") mock_stream.recv = AsyncMock(side_effect=exc) @@ -257,67 +312,80 @@ async def test_error_uses_put_nowait(self): mux = _StreamMultiplexer(mock_stream, queue_max_size=1) queue = mux.register({10}) queue.put_nowait("filler") - mux._ensure_recv_loop() + # When the receive loop is started + mux._ensure_recv_loop() await asyncio.sleep(0.05) - # Queue is full (maxsize=1), but _put_error_nowait pops existing items - # to ensure the error gets recorded. 
+ # Then the error is recorded even if the queue was full assert queue.qsize() == 1 err = queue.get_nowait() assert isinstance(err, _StreamError) assert err.exception is exc - mux.unregister({10}) @pytest.mark.asyncio async def test_unknown_read_id_is_dropped(self): + # Given a multiplexer and a response with an unknown read ID mock_stream = AsyncMock() resp = _make_response(read_id=999) mock_stream.recv = AsyncMock(side_effect=[resp, None]) mux = _StreamMultiplexer(mock_stream) queue = mux.register({10}) + + # When the receive loop is started mux._ensure_recv_loop() + # Then the response is dropped and only StreamEnd is received end = await asyncio.wait_for(queue.get(), timeout=1) assert isinstance(end, _StreamEnd) - mux.unregister({10}) class TestSend: @pytest.mark.asyncio async def test_send_forwards_to_stream(self): + # Given a multiplexer and a request mock_stream = AsyncMock() mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) mux = _StreamMultiplexer(mock_stream) - request = _storage_v2.BidiReadObjectRequest( read_ranges=[ _storage_v2.ReadRange(read_id=1, read_offset=0, read_length=10) ] ) + + # When sending the request gen = await mux.send(request) + + # Then it is forwarded to the stream and current generation is returned mock_stream.send.assert_called_once_with(request) assert gen == 0 @pytest.mark.asyncio async def test_send_returns_current_generation(self): + # Given a multiplexer at generation 5 mock_stream = AsyncMock() mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) mux = _StreamMultiplexer(mock_stream) mux._stream_generation = 5 - request = _storage_v2.BidiReadObjectRequest() + + # When sending a request gen = await mux.send(request) + + # Then it returns the current generation assert gen == 5 @pytest.mark.asyncio async def test_send_propagates_exception(self): + # Given a multiplexer where send fails mock_stream = AsyncMock() mock_stream.send = AsyncMock(side_effect=RuntimeError("send failed")) mock_stream.recv = 
AsyncMock(side_effect=asyncio.Event().wait) mux = _StreamMultiplexer(mock_stream) + # When sending a request + # Then the exception is propagated with pytest.raises(RuntimeError, match="send failed"): await mux.send(_storage_v2.BidiReadObjectRequest()) @@ -325,25 +393,27 @@ async def test_send_propagates_exception(self): class TestReopenStream: @pytest.mark.asyncio async def test_reopen_bumps_generation_and_replaces_stream(self): + # Given a multiplexer with a registered queue old_stream = AsyncMock() old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) mux = _StreamMultiplexer(old_stream) mux.register({1}) - assert mux.stream_generation == 0 new_stream = AsyncMock() new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) factory = AsyncMock(return_value=new_stream) + # When the stream is reopened with the correct generation await mux.reopen_stream(0, factory) + # Then the generation is bumped and the stream is replaced assert mux.stream_generation == 1 assert mux._stream is new_stream factory.assert_called_once() - mux.unregister({1}) @pytest.mark.asyncio async def test_reopen_skips_if_generation_mismatch(self): + # Given a multiplexer at generation 5 mock_stream = AsyncMock() mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) mux = _StreamMultiplexer(mock_stream) @@ -351,14 +421,17 @@ async def test_reopen_skips_if_generation_mismatch(self): mux.register({1}) factory = AsyncMock() + + # When reopen is called with a mismatched generation (3) await mux.reopen_stream(3, factory) + # Then the reopen is skipped and generation remains unchanged assert mux.stream_generation == 5 factory.assert_not_called() - mux.unregister({1}) @pytest.mark.asyncio async def test_reopen_broadcasts_error_before_bump(self): + # Given a multiplexer with a registered queue old_stream = AsyncMock() old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) mux = _StreamMultiplexer(old_stream) @@ -368,15 +441,17 @@ async def 
test_reopen_broadcasts_error_before_bump(self): new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) factory = AsyncMock(return_value=new_stream) + # When the stream is reopened await mux.reopen_stream(0, factory) + # Then a StreamError is broadcast to the queue before the bump err = queue.get_nowait() assert isinstance(err, _StreamError) assert err.generation == 0 - mux.unregister({1}) @pytest.mark.asyncio async def test_reopen_starts_new_recv_loop(self): + # Given a multiplexer with a registered queue and an active recv task old_stream = AsyncMock() old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) mux = _StreamMultiplexer(old_stream) @@ -387,14 +462,16 @@ async def test_reopen_starts_new_recv_loop(self): new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) factory = AsyncMock(return_value=new_stream) + # When the stream is reopened await mux.reopen_stream(0, factory) + # Then a new receive loop task is started assert mux._recv_task is not old_recv_task assert not mux._recv_task.done() - mux.unregister({1}) @pytest.mark.asyncio async def test_reopen_closes_old_stream_best_effort(self): + # Given a multiplexer where closing the old stream raises an error old_stream = AsyncMock() old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) old_stream.close = AsyncMock(side_effect=RuntimeError("close failed")) @@ -405,12 +482,15 @@ async def test_reopen_closes_old_stream_best_effort(self): new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) factory = AsyncMock(return_value=new_stream) + # When the stream is reopened await mux.reopen_stream(0, factory) + + # Then the reopen still succeeds assert mux.stream_generation == 1 - mux.unregister({1}) @pytest.mark.asyncio async def test_concurrent_reopen_only_one_wins(self): + # Given a multiplexer and a counting factory old_stream = AsyncMock() old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) mux = _StreamMultiplexer(old_stream) @@ -425,14 +505,15 @@ async def 
counting_factory(): new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) return new_stream + # When concurrent reopen calls are made await asyncio.gather( mux.reopen_stream(0, counting_factory), mux.reopen_stream(0, counting_factory), ) + # Then only one factory call is made and generation is bumped once assert call_count == 1 assert mux.stream_generation == 1 - mux.unregister({1}) @pytest.mark.asyncio async def test_reopen_factory_failure_leaves_generation_unchanged(self): @@ -440,6 +521,7 @@ async def test_reopen_factory_failure_leaves_generation_unchanged(self): is not restarted. The caller's retry manager will re-attempt reopen with the same generation, which will succeed because the generation check still matches.""" + # Given a multiplexer and a failing factory old_stream = AsyncMock() old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) mux = _StreamMultiplexer(old_stream) @@ -447,30 +529,32 @@ async def test_reopen_factory_failure_leaves_generation_unchanged(self): failing_factory = AsyncMock(side_effect=RuntimeError("open failed")) + # When reopen fails with pytest.raises(RuntimeError, match="open failed"): await mux.reopen_stream(0, failing_factory) - # Generation was NOT bumped + # Then generation is NOT bumped and recv loop is stopped assert mux.stream_generation == 0 - # Recv loop was stopped and not restarted assert mux._recv_task is None or mux._recv_task.done() - # A subsequent reopen with the same generation succeeds + # Given a subsequent successful reopen new_stream = AsyncMock() new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) ok_factory = AsyncMock(return_value=new_stream) + # When reopen is called again with the same generation await mux.reopen_stream(0, ok_factory) + # Then it succeeds assert mux.stream_generation == 1 assert mux._stream is new_stream assert mux._recv_task is not None and not mux._recv_task.done() - mux.unregister({1}) class TestClose: @pytest.mark.asyncio async def 
test_close_cancels_recv_loop(self): + # Given a multiplexer with an active receive loop mock_stream = AsyncMock() mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) mux = _StreamMultiplexer(mock_stream) @@ -478,19 +562,25 @@ async def test_close_cancels_recv_loop(self): mux._ensure_recv_loop() recv_task = mux._recv_task + # When closing the multiplexer await mux.close() + + # Then the receive loop task is cancelled assert recv_task.cancelled() @pytest.mark.asyncio async def test_close_broadcasts_terminal_error(self): + # Given a multiplexer with registered queues mock_stream = AsyncMock() mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) mux = _StreamMultiplexer(mock_stream) q1 = mux.register({1}) q2 = mux.register({2}) + # When closing the multiplexer await mux.close() + # Then a terminal StreamError is broadcast to all queues err1 = q1.get_nowait() err2 = q2.get_nowait() assert isinstance(err1, _StreamError) @@ -498,6 +588,10 @@ async def test_close_broadcasts_terminal_error(self): @pytest.mark.asyncio async def test_close_with_no_tasks_is_noop(self): + # Given a multiplexer with no active tasks mock_stream = AsyncMock() mux = _StreamMultiplexer(mock_stream) + + # When closing the multiplexer + # Then it should not raise any error await mux.close() # should not raise From 18c321e67d5c188203c77df5cd2a718ad5c4f42a Mon Sep 17 00:00:00 2001 From: zhixiangli Date: Wed, 8 Apr 2026 02:51:23 +0000 Subject: [PATCH 12/18] test: cleanup redundant mock setups in test_async_multi_range_downloader.py --- .../tests/unit/asyncio/test_async_multi_range_downloader.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py b/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py index 19847db015eb..2011c52a7fe1 100644 --- a/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py +++ 
b/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py @@ -60,9 +60,6 @@ async def _make_mock_mrd( mock_stream.generation_number = _TEST_GENERATION_NUMBER mock_stream.persisted_size = _TEST_OBJECT_SIZE mock_stream.read_handle = _TEST_READ_HANDLE - mock_stream.is_stream_open = True - # Default recv blocks forever (tests override with specific side_effect) - mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) mrd = await AsyncMultiRangeDownloader.create_mrd( mock_client, bucket_name, object_name, generation, read_handle @@ -422,8 +419,6 @@ async def test_create_mrd_with_generation_number( mock_stream.generation_number = _TEST_GENERATION_NUMBER mock_stream.persisted_size = _TEST_OBJECT_SIZE mock_stream.read_handle = _TEST_READ_HANDLE - mock_stream.is_stream_open = True - mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) # Act mrd = await AsyncMultiRangeDownloader.create_mrd( @@ -561,7 +556,6 @@ async def staged_recv(): return None mock_mrd.read_obj_str.recv = AsyncMock(side_effect=staged_recv) - mock_mrd.read_obj_str.is_stream_open = True mock_random_int.return_value = 123 From 236ef04a91acbf1f48e006a2a09a905d033c2a0a Mon Sep 17 00:00:00 2001 From: zhixiangli Date: Wed, 8 Apr 2026 03:14:56 +0000 Subject: [PATCH 13/18] chore: add logging to _stream_multiplexer.py recv loop failure --- .../google/cloud/storage/asyncio/_stream_multiplexer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py index caabb5b8f089..d8c569c6879a 100644 --- a/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py +++ b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py @@ -154,6 +154,7 @@ async def _recv_loop(self) -> None: except asyncio.CancelledError: raise except Exception as e: + logger.warning(f"Stream 
multiplexer recv loop failed: {e}", exc_info=True) error = _StreamError(e, self._stream_generation) for queue in self._get_unique_queues(): self._put_error_nowait(queue, error) From 311c0539487a3db1bc8df0f57d0a8663fbc95a48 Mon Sep 17 00:00:00 2001 From: zhixiangli Date: Wed, 8 Apr 2026 03:15:40 +0000 Subject: [PATCH 14/18] test: add more assertions to AsyncMultiRangeDownloader creation test --- .../tests/unit/asyncio/test_async_multi_range_downloader.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py b/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py index 2011c52a7fe1..3b8d281688e8 100644 --- a/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py +++ b/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py @@ -431,6 +431,8 @@ async def test_create_mrd_with_generation_number( # Assert assert mrd.generation == _TEST_GENERATION_NUMBER + assert mrd.read_handle == _TEST_READ_HANDLE + assert mrd.persisted_size == _TEST_OBJECT_SIZE assert "'generation_number' is deprecated" in caplog.text @pytest.mark.asyncio From cde5c207de5d5e01a6b89a0387249a4c504504c1 Mon Sep 17 00:00:00 2001 From: zhixiangli Date: Wed, 8 Apr 2026 03:35:32 +0000 Subject: [PATCH 15/18] test: refactor test_async_multi_range_downloader.py with Arrange/Act/Assert --- .../test_async_multi_range_downloader.py | 60 +++++++++++++------ 1 file changed, 41 insertions(+), 19 deletions(-) diff --git a/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py b/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py index 3b8d281688e8..1ac17ad3cc8e 100644 --- a/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py +++ b/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py @@ -185,6 +185,7 @@ async def 
controlled_recv(): async def test_download_ranges( self, mock_cls_async_read_object_stream, mock_random_int ): + # Arrange data = b"these_are_18_chars" crc32c = Checksum(data).digest() crc32c_int = int.from_bytes(crc32c, "big") @@ -193,27 +194,36 @@ async def test_download_ranges( mock_random_int.side_effect = [456] mock_mrd.read_obj_str.send = AsyncMock() - mock_mrd.read_obj_str.recv = AsyncMock( - side_effect=[ - _storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - _storage_v2.ObjectRangeData( - checksummed_data=_storage_v2.ChecksummedData( - content=data, crc32c=crc32c_int - ), - range_end=True, - read_range=_storage_v2.ReadRange( - read_offset=0, read_length=18, read_id=456 - ), - ) - ], - ), - None, - ] - ) - + mock_mrd.read_obj_str.recv = AsyncMock() + mock_mrd.read_obj_str.recv.side_effect = [ + _storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + _storage_v2.ObjectRangeData( + checksummed_data=_storage_v2.ChecksummedData( + content=data, crc32c=crc32c_int + ), + range_end=True, + read_range=_storage_v2.ReadRange( + read_offset=0, read_length=18, read_id=456 + ), + ) + ], + ), + None, + ] + # Act buffer = BytesIO() + await mock_mrd.download_ranges([(0, 18, buffer)]) + + # Assert + mock_mrd.read_obj_str.send.assert_called_once_with( + _storage_v2.BidiReadObjectRequest( + read_ranges=[ + _storage_v2.ReadRange(read_offset=0, read_length=18, read_id=456) + ] + ) + ) assert buffer.getvalue() == data @pytest.mark.asyncio @@ -523,6 +533,7 @@ async def test_on_open_error_logs_warning(self, mock_logger): async def test_download_ranges_resumption_logging( self, mock_cls_async_read_object_stream, mock_random_int, mock_logger ): + # Arrange mock_mrd, _ = await self._make_mock_mrd(mock_cls_async_read_object_stream) from google.api_core import exceptions as core_exceptions @@ -561,11 +572,22 @@ async def staged_recv(): mock_random_int.return_value = 123 + # Act buffer = BytesIO() + + # Patch Checksum where it is likely used (reads_resumption_strategy or 
similar), + # but actually if we use google_crc32c directly, we should patch that or provide valid CRC. + # Since we can't reliably predict where Checksum is imported/used without more digging, + # let's provide a valid CRC for b"data". + # Checksum(b"data").digest() -> needs to match crc32c=123. + # But we can't force b"data" to have crc=123. + # So we MUST patch Checksum. + # It is used in google.cloud.storage.asyncio.retry.reads_resumption_strategy with mock.patch( "google.cloud.storage.asyncio.retry.reads_resumption_strategy.Checksum" ) as mock_chk: mock_chk.return_value.digest.return_value = (123).to_bytes(4, "big") await mock_mrd.download_ranges([(0, 4, buffer)]) + # Assert mock_logger.info.assert_any_call("Resuming download (attempt 2) for 1 ranges.") From eabf700cd59ecd6d00c6ae0a9f26076274383ab9 Mon Sep 17 00:00:00 2001 From: zhixiangli Date: Wed, 8 Apr 2026 04:32:42 +0000 Subject: [PATCH 16/18] test: refactor test_mrd_concurrent_download_out_of_bounds in test_zonal.py --- .../tests/system/test_zonal.py | 26 ++++++------------- 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/packages/google-cloud-storage/tests/system/test_zonal.py b/packages/google-cloud-storage/tests/system/test_zonal.py index 3c9e7583a46a..3c275794e331 100644 --- a/packages/google-cloud-storage/tests/system/test_zonal.py +++ b/packages/google-cloud-storage/tests/system/test_zonal.py @@ -759,34 +759,24 @@ async def _run(): async with AsyncMultiRangeDownloader( grpc_client_direct, _ZONAL_BUCKET, object_name ) as mrd: - b_valid = BytesIO() - t_valid = asyncio.create_task(mrd.download_ranges([(0, 100, b_valid)])) - - b_oob1 = BytesIO() - t_oob1 = asyncio.create_task( - mrd.download_ranges([(object_size + 1000, 100, b_oob1)]) + valid_buffer = BytesIO() + valid_task = asyncio.create_task( + mrd.download_ranges([(0, 100, valid_buffer)]) ) - # EOF ask for 100 bytes - b_oob2 = BytesIO() - t_oob2 = asyncio.create_task( - mrd.download_ranges([(object_size, 100, b_oob2)]) + oob_buffer = 
BytesIO() + oob_task = asyncio.create_task( + mrd.download_ranges([(object_size + 1000, 100, oob_buffer)]) ) - results = await asyncio.gather( - t_valid, t_oob1, t_oob2, return_exceptions=True - ) + results = await asyncio.gather(valid_task, oob_task, return_exceptions=True) # Verify valid one processed correctly - assert b_valid.getvalue() == object_data[:100] + assert valid_buffer.getvalue() == object_data[:100] # Verify fully OOB request returned Exception assert isinstance(results[1], OutOfRange) - # Verify request exactly at EOF successfully completed with 0 bytes - assert not isinstance(results[2], Exception) - assert b_oob2.getvalue() == b"" - del writer gc.collect() blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) From 80914905a8069adc92643bacde17bc51766f10f1 Mon Sep 17 00:00:00 2001 From: zhixiangli Date: Wed, 8 Apr 2026 14:48:45 +0000 Subject: [PATCH 17/18] fix: update stream end condition to use grpc.aio.EOF in _stream_multiplexer --- .../cloud/storage/asyncio/_stream_multiplexer.py | 3 ++- .../tests/unit/asyncio/test_stream_multiplexer.py | 11 ++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py index d8c569c6879a..fd714b572d82 100644 --- a/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py +++ b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py @@ -15,6 +15,7 @@ from __future__ import annotations import asyncio +import grpc import logging from typing import Awaitable, Callable, Dict, Optional, Set @@ -121,7 +122,7 @@ async def _recv_loop(self) -> None: try: while True: response = await self._stream.recv() - if response is None: + if response == grpc.aio.EOF: sentinel = _StreamEnd() await asyncio.gather( *( diff --git 
a/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py b/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py index e2a15b482148..08d3f678f118 100644 --- a/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py +++ b/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +import grpc from unittest.mock import AsyncMock import pytest @@ -196,7 +197,7 @@ async def test_routes_response_by_read_id(self): mock_stream = AsyncMock() resp1 = _make_response(read_id=10, data=b"hello") resp2 = _make_response(read_id=20, data=b"world") - mock_stream.recv = AsyncMock(side_effect=[resp1, resp2, None]) + mock_stream.recv = AsyncMock(side_effect=[resp1, resp2, grpc.aio.EOF]) mux = _StreamMultiplexer(mock_stream) q1 = mux.register({10}) @@ -222,7 +223,7 @@ async def test_deduplicates_when_multiple_read_ids_map_to_same_queue(self): # Given a multiplexer with multiple read IDs mapped to the same queue mock_stream = AsyncMock() resp = _make_multi_range_response([10, 11]) - mock_stream.recv = AsyncMock(side_effect=[resp, None]) + mock_stream.recv = AsyncMock(side_effect=[resp, grpc.aio.EOF]) mux = _StreamMultiplexer(mock_stream) queue = mux.register({10, 11}) @@ -244,7 +245,7 @@ async def test_metadata_only_response_broadcast_to_all(self): metadata_resp = _storage_v2.BidiReadObjectResponse( read_handle=_storage_v2.BidiReadHandle(handle=b"handle") ) - mock_stream.recv = AsyncMock(side_effect=[metadata_resp, None]) + mock_stream.recv = AsyncMock(side_effect=[metadata_resp, grpc.aio.EOF]) mux = _StreamMultiplexer(mock_stream) q1 = mux.register({10}) @@ -263,7 +264,7 @@ async def test_metadata_only_response_broadcast_to_all(self): async def test_stream_end_sends_sentinel_to_all_queues(self): # Given a multiplexer with multiple registered queues and a stream that ends immediately mock_stream = AsyncMock() - mock_stream.recv = 
AsyncMock(return_value=None) + mock_stream.recv = AsyncMock(return_value=grpc.aio.EOF) mux = _StreamMultiplexer(mock_stream) q1 = mux.register({10}) @@ -328,7 +329,7 @@ async def test_unknown_read_id_is_dropped(self): # Given a multiplexer and a response with an unknown read ID mock_stream = AsyncMock() resp = _make_response(read_id=999) - mock_stream.recv = AsyncMock(side_effect=[resp, None]) + mock_stream.recv = AsyncMock(side_effect=[resp, grpc.aio.EOF]) mux = _StreamMultiplexer(mock_stream) queue = mux.register({10}) From d7b3f5445a7d732a6bcd0d825fe84b828594f4f9 Mon Sep 17 00:00:00 2001 From: zhixiangli Date: Thu, 9 Apr 2026 11:13:18 +0000 Subject: [PATCH 18/18] fix: log warning for unregistered read_id in stream multiplexer Added a warning log when the stream multiplexer receives data for an unregistered read_id, and updated the corresponding test to verify the logging behavior. Also applied linter formatting to the modified files. --- .../google/cloud/storage/asyncio/_stream_multiplexer.py | 7 ++++++- .../tests/unit/asyncio/test_stream_multiplexer.py | 7 +++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py index fd714b572d82..3bcac2ab9cfb 100644 --- a/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py +++ b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py @@ -15,10 +15,11 @@ from __future__ import annotations import asyncio -import grpc import logging from typing import Awaitable, Callable, Dict, Optional, Set +import grpc + from google.cloud import _storage_v2 from google.cloud.storage.asyncio.async_read_object_stream import ( _AsyncReadObjectStream, @@ -139,6 +140,10 @@ async def _recv_loop(self) -> None: queue = self._queues.get(read_id) if queue: queues_to_notify.add(queue) + else: + logger.warning( + f"Received 
data for unregistered read_id: {read_id}" + ) await asyncio.gather( *( self._put_with_timeout(queue, response) diff --git a/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py b/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py index 08d3f678f118..9c75b9e8e1e1 100644 --- a/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py +++ b/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py @@ -13,9 +13,9 @@ # limitations under the License. import asyncio -import grpc from unittest.mock import AsyncMock +import grpc import pytest from google.cloud import _storage_v2 @@ -325,7 +325,7 @@ async def test_error_uses_put_nowait(self): assert err.exception is exc @pytest.mark.asyncio - async def test_unknown_read_id_is_dropped(self): + async def test_unknown_read_id_is_dropped(self, caplog): # Given a multiplexer and a response with an unknown read ID mock_stream = AsyncMock() resp = _make_response(read_id=999) @@ -341,6 +341,9 @@ async def test_unknown_read_id_is_dropped(self): end = await asyncio.wait_for(queue.get(), timeout=1) assert isinstance(end, _StreamEnd) + # And a warning is logged + assert "Received data for unregistered read_id: 999" in caplog.text + class TestSend: @pytest.mark.asyncio