diff --git a/packages/google-cloud-storage/google/cloud/storage/asyncio/async_read_object_stream.py b/packages/google-cloud-storage/google/cloud/storage/asyncio/async_read_object_stream.py index cd7ae067c631..8fd98d623571 100644 --- a/packages/google-cloud-storage/google/cloud/storage/asyncio/async_read_object_stream.py +++ b/packages/google-cloud-storage/google/cloud/storage/asyncio/async_read_object_stream.py @@ -79,6 +79,9 @@ def __init__( self.socket_like_rpc: Optional[AsyncBidiRpc] = None self._is_stream_open: bool = False self.persisted_size: Optional[int] = None + self.is_finalized: bool = False + self.full_obj_server_crc32c: Optional[int] = None + self.object_metadata: Optional[_storage_v2.Object] = None async def open(self, metadata: Optional[List[Tuple[str, str]]] = None) -> None: """Opens the bidi-gRPC connection to read from the object. @@ -132,6 +135,18 @@ async def open(self, metadata: Optional[List[Tuple[str, str]]] = None) -> None: self.generation_number = response.metadata.generation # update persisted size self.persisted_size = response.metadata.size + self.object_metadata = response.metadata + if ( + hasattr(response.metadata, "finalize_time") + and response.metadata.finalize_time + and response.metadata.finalize_time.second > 0 + ): + self.is_finalized = True + if ( + hasattr(response.metadata, "checksums") + and response.metadata.checksums + ): + self.full_obj_server_crc32c = response.metadata.checksums.crc32c if response and response.read_handle: self.read_handle = response.read_handle diff --git a/packages/google-cloud-storage/google/cloud/storage/asyncio/retry/reads_resumption_strategy.py b/packages/google-cloud-storage/google/cloud/storage/asyncio/retry/reads_resumption_strategy.py index 845770c3a215..6cf17af19089 100644 --- a/packages/google-cloud-storage/google/cloud/storage/asyncio/retry/reads_resumption_strategy.py +++ b/packages/google-cloud-storage/google/cloud/storage/asyncio/retry/reads_resumption_strategy.py @@ -36,7 +36,12 @@ class _DownloadState: """A helper class to track the state of a single range download.""" def __init__( - self, initial_offset: int, initial_length: int, user_buffer: IO[bytes] + self, + initial_offset: int, + initial_length: int, + user_buffer: IO[bytes], + is_full_object_read: bool = False, + enable_checksum: bool = True, ): self.initial_offset = initial_offset self.initial_length = initial_length @@ -44,6 +49,12 @@ def __init__( self.bytes_written = 0 self.next_expected_offset = initial_offset self.is_complete = False + self.is_full_object_read = is_full_object_read + self.rolling_checksum = ( + google_crc32c.Checksum() + if (is_full_object_read and enable_checksum) + else None + ) class _ReadResumptionStrategy(_BaseResumptionStrategy): @@ -90,6 +101,7 @@ def update_state_from_response( ) download_states = state["download_states"] + checksum_enabled = state.get("enable_checksum", True) for object_data_range in proto.object_data_ranges: # Ignore empty ranges or ranges for IDs not in our state @@ -125,7 +137,7 @@ def update_state_from_response( checksummed_data = object_data_range.checksummed_data data = checksummed_data.content - if checksummed_data.HasField("crc32c"): + if checksum_enabled and checksummed_data.HasField("crc32c"): server_checksum = checksummed_data.crc32c client_checksum = google_crc32c.value(data) if server_checksum != client_checksum: @@ -138,10 +150,14 @@ def update_state_from_response( # Update State & Write Data chunk_size = len(data) read_state.user_buffer.write(data) + + # Commit updates only after the write succeeds + if checksum_enabled and read_state.rolling_checksum is not None: + read_state.rolling_checksum.update(data) read_state.bytes_written += chunk_size read_state.next_expected_offset += chunk_size - # Final Byte Count Verification + # Final Byte Count & Full Object Checksum Verification if object_data_range.range_end: read_state.is_complete = True if ( @@ -154,6 +170,26 @@ def update_state_from_response( f"Expected {read_state.initial_length}, got {read_state.bytes_written}", ) + # Perform full-object checksum verification once the stream finishes. + if ( + read_state.is_full_object_read + and checksum_enabled + and read_state.rolling_checksum is not None + ): + full_obj_server_crc32c = state.get("full_obj_server_crc32c") + if full_obj_server_crc32c is not None: + # Use standard big-endian byte conversion to retrieve the rolling checksum value. + client_checksum = int.from_bytes( + read_state.rolling_checksum.digest(), + byteorder="big", + ) + if client_checksum != full_obj_server_crc32c: + raise DataCorruption( + response, + f"Full object checksum mismatch for read_id {read_id}. " + f"Server authoritative crc32c: {full_obj_server_crc32c}, client calculated rolling: {client_checksum}.", + ) + async def recover_state_on_failure(self, error: Exception, state: Any) -> None: """Handles BidiReadObjectRedirectedError for reads.""" routing_token, read_handle = _handle_redirect(error) diff --git a/packages/google-cloud-storage/tests/unit/asyncio/retry/test_reads_resumption_strategy.py b/packages/google-cloud-storage/tests/unit/asyncio/retry/test_reads_resumption_strategy.py index dc27cb701974..841ea655626e 100644 --- a/packages/google-cloud-storage/tests/unit/asyncio/retry/test_reads_resumption_strategy.py +++ b/packages/google-cloud-storage/tests/unit/asyncio/retry/test_reads_resumption_strategy.py @@ -45,6 +45,48 @@ def test_initialization(self): self.assertEqual(state.bytes_written, 0) self.assertEqual(state.next_expected_offset, initial_offset) self.assertFalse(state.is_complete) + self.assertFalse(state.is_full_object_read) + self.assertIsNone(state.rolling_checksum) + + def test_initialization_with_full_object_read(self): + """Test that _DownloadState initializes correctly when is_full_object_read is True.""" + initial_offset = 10 + initial_length = 100 + user_buffer = io.BytesIO() + state_full = _DownloadState( + initial_offset, initial_length, user_buffer, is_full_object_read=True + ) + + self.assertEqual(state_full.initial_offset, initial_offset) + self.assertEqual(state_full.initial_length, initial_length) + self.assertEqual(state_full.user_buffer, user_buffer) + self.assertEqual(state_full.bytes_written, 0) + self.assertEqual(state_full.next_expected_offset, initial_offset) + self.assertFalse(state_full.is_complete) + self.assertTrue(state_full.is_full_object_read) + self.assertIsNotNone(state_full.rolling_checksum) + + def test_initialization_with_full_object_read_and_checksum_disabled(self): + """Test that _DownloadState does not initialize rolling_checksum when enable_checksum is False.""" + initial_offset = 10 + initial_length = 100 + user_buffer = io.BytesIO() + state_full = _DownloadState( + initial_offset, + initial_length, + user_buffer, + is_full_object_read=True, + enable_checksum=False, + ) + + self.assertEqual(state_full.initial_offset, initial_offset) + self.assertEqual(state_full.initial_length, initial_length) + self.assertEqual(state_full.user_buffer, user_buffer) + self.assertEqual(state_full.bytes_written, 0) + self.assertEqual(state_full.next_expected_offset, initial_offset) + self.assertFalse(state_full.is_complete) + self.assertTrue(state_full.is_full_object_read) + self.assertIsNone(state_full.rolling_checksum) class TestReadResumptionStrategy(unittest.TestCase): @@ -53,12 +95,24 @@ def setUp(self): self.state = {"download_states": {}, "read_handle": None, "routing_token": None} - def _add_download(self, read_id, offset=0, length=100, buffer=None): + def _add_download( + self, + read_id, + offset=0, + length=100, + buffer=None, + is_full_object_read=False, + enable_checksum=True, + ): """Helper to inject a download state into the correct nested location.""" if buffer is None: buffer = io.BytesIO() state = _DownloadState( - initial_offset=offset, initial_length=length, user_buffer=buffer + initial_offset=offset, + initial_length=length, + user_buffer=buffer, + is_full_object_read=is_full_object_read, + enable_checksum=enable_checksum, ) self.state["download_states"][read_id] = state return state @@ -358,3 +412,61 @@ async def run(): # Token should remain unchanged self.assertEqual(self.state["routing_token"], "existing-token") + + def test_update_state_full_object_checksum_success(self): + """Test that full object checksum verification succeeds on range_end.""" + read_state = self._add_download( + _READ_ID, offset=0, length=9, is_full_object_read=True + ) + self.state["enable_checksum"] = True + self.state["full_obj_server_crc32c"] = google_crc32c.value(b"testdata1") + + resp1 = self._create_response(b"test", _READ_ID, offset=0) + self.strategy.update_state_from_response(resp1, self.state) + + resp2 = self._create_response(b"data1", _READ_ID, offset=4, range_end=True) + self.strategy.update_state_from_response(resp2, self.state) + + self.assertTrue(read_state.is_complete) + self.assertEqual(read_state.bytes_written, 9) + + def test_update_state_full_object_checksum_failure(self): + """Test that full object checksum verification raises DataCorruption on mismatch at range_end.""" + self._add_download(_READ_ID, offset=0, length=9, is_full_object_read=True) + self.state["enable_checksum"] = True + self.state["full_obj_server_crc32c"] = 111111 # Wrong server checksum! + + resp1 = self._create_response(b"test", _READ_ID, offset=0) + self.strategy.update_state_from_response(resp1, self.state) + + resp2 = self._create_response(b"data1", _READ_ID, offset=4, range_end=True) + with self.assertRaisesRegex(DataCorruption, "Full object checksum mismatch"): + self.strategy.update_state_from_response(resp2, self.state) + + def test_update_state_checksum_mismatch_ignored_when_disabled(self): + """Test that a CRC32C mismatch is ignored when enable_checksum is False.""" + self._add_download(_READ_ID) + self.state["enable_checksum"] = False + response = self._create_response(b"data", _READ_ID, offset=0, crc=999999) + + # Should NOT raise DataCorruption! + self.strategy.update_state_from_response(response, self.state) + + def test_update_state_full_object_checksum_mismatch_ignored_when_disabled(self): + """Test that a full-object CRC32C mismatch is ignored when enable_checksum is False.""" + self._add_download( + _READ_ID, + offset=0, + length=9, + is_full_object_read=True, + enable_checksum=False, + ) + self.state["enable_checksum"] = False + self.state["full_obj_server_crc32c"] = 111111 # Wrong server checksum! + + resp1 = self._create_response(b"test", _READ_ID, offset=0) + self.strategy.update_state_from_response(resp1, self.state) + + resp2 = self._create_response(b"data1", _READ_ID, offset=4, range_end=True) + # Should NOT raise DataCorruption! + self.strategy.update_state_from_response(resp2, self.state) diff --git a/packages/google-cloud-storage/tests/unit/asyncio/test_async_read_object_stream.py b/packages/google-cloud-storage/tests/unit/asyncio/test_async_read_object_stream.py index f5783be6bf94..a8f64422765e 100644 --- a/packages/google-cloud-storage/tests/unit/asyncio/test_async_read_object_stream.py +++ b/packages/google-cloud-storage/tests/unit/asyncio/test_async_read_object_stream.py @@ -38,9 +38,11 @@ async def instantiate_read_obj_stream(mock_client, mock_cls_async_bidi_rpc, open socket_like_rpc.open = AsyncMock() recv_response = mock.MagicMock(spec=_storage_v2.BidiReadObjectResponse) - recv_response.metadata = mock.MagicMock(spec=_storage_v2.Object) + recv_response.metadata = mock.MagicMock() recv_response.metadata.generation = _TEST_GENERATION_NUMBER recv_response.metadata.size = _TEST_OBJECT_SIZE + recv_response.metadata.finalize_time.second = 30 + recv_response.metadata.checksums.crc32c = 98765 recv_response.read_handle = _TEST_READ_HANDLE socket_like_rpc.recv = AsyncMock(return_value=recv_response) @@ -130,6 +132,8 @@ async def test_open(mock_client, mock_cls_async_bidi_rpc): assert read_obj_stream.generation_number == _TEST_GENERATION_NUMBER assert read_obj_stream.read_handle == _TEST_READ_HANDLE assert read_obj_stream.persisted_size == _TEST_OBJECT_SIZE + assert read_obj_stream.is_finalized is True + assert read_obj_stream.full_obj_server_crc32c == 98765 assert read_obj_stream.is_stream_open @@ -381,3 +385,36 @@ async def test_recv_updates_read_handle_on_refresh( await stream.recv() assert stream.read_handle == refreshed_handle + + +@mock.patch("google.cloud.storage.asyncio.async_read_object_stream.AsyncBidiRpc") +@mock.patch( + "google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" +) +@pytest.mark.asyncio +async def test_open_unfinalized_object_skips_checksum( + mock_client, mock_cls_async_bidi_rpc +): + socket_like_rpc = AsyncMock() + mock_cls_async_bidi_rpc.return_value = socket_like_rpc + socket_like_rpc.open = AsyncMock() + + recv_response = mock.MagicMock(spec=_storage_v2.BidiReadObjectResponse) + recv_response.metadata = mock.MagicMock() + recv_response.metadata.generation = _TEST_GENERATION_NUMBER + recv_response.metadata.size = _TEST_OBJECT_SIZE + recv_response.metadata.finalize_time.second = 0 # NOT finalized! + recv_response.metadata.checksums.crc32c = 98765 + recv_response.read_handle = _TEST_READ_HANDLE + socket_like_rpc.recv = AsyncMock(return_value=recv_response) + + read_obj_stream = _AsyncReadObjectStream( + client=mock_client, + bucket_name=_TEST_BUCKET_NAME, + object_name=_TEST_OBJECT_NAME, + ) + + await read_obj_stream.open() + + assert read_obj_stream.is_finalized is False + assert read_obj_stream.full_obj_server_crc32c is None diff --git a/packages/google-cloud-storage/tests/unit/conftest.py b/packages/google-cloud-storage/tests/unit/conftest.py new file mode 100644 index 000000000000..2eeabdc990e6 --- /dev/null +++ b/packages/google-cloud-storage/tests/unit/conftest.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import pytest + + +@pytest.fixture(autouse=True) +def set_event_loop(): + try: + asyncio.get_running_loop() + yield + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + yield + finally: + loop.close() + asyncio.set_event_loop(None)