Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def __init__(
self.socket_like_rpc: Optional[AsyncBidiRpc] = None
self._is_stream_open: bool = False
self.persisted_size: Optional[int] = None
self.is_finalized: bool = False
self.full_obj_server_crc32c: Optional[int] = None
self.object_metadata: Optional[_storage_v2.Object] = None

async def open(self, metadata: Optional[List[Tuple[str, str]]] = None) -> None:
"""Opens the bidi-gRPC connection to read from the object.
Expand Down Expand Up @@ -132,6 +135,18 @@ async def open(self, metadata: Optional[List[Tuple[str, str]]] = None) -> None:
self.generation_number = response.metadata.generation
# update persisted size
self.persisted_size = response.metadata.size
self.object_metadata = response.metadata
if (
hasattr(response.metadata, "finalize_time")
and response.metadata.finalize_time
and response.metadata.finalize_time.second > 0
):
self.is_finalized = True
if (
hasattr(response.metadata, "checksums")
and response.metadata.checksums
):
self.full_obj_server_crc32c = response.metadata.checksums.crc32c

if response and response.read_handle:
self.read_handle = response.read_handle
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,25 @@ class _DownloadState:
"""A helper class to track the state of a single range download."""

def __init__(
    self,
    initial_offset: int,
    initial_length: int,
    user_buffer: IO[bytes],
    is_full_object_read: bool = False,
    enable_checksum: bool = True,
):
    """Set up the bookkeeping for one range download.

    Args:
        initial_offset: Offset at which the range starts; also the first
            value of ``next_expected_offset``.
        initial_length: Length requested for the range.
        user_buffer: Writable binary buffer that receives the data.
        is_full_object_read: Whether this range is flagged as a read of
            the full object.
        enable_checksum: Whether checksum validation is enabled.
    """
    # Request parameters, stored exactly as given.
    self.initial_offset = initial_offset
    self.initial_length = initial_length
    self.user_buffer = user_buffer
    self.is_full_object_read = is_full_object_read

    # Progress tracking; nothing has been received yet.
    self.bytes_written = 0
    self.next_expected_offset = initial_offset
    self.is_complete = False

    # A rolling CRC32C is only kept for full-object reads with checksums
    # enabled; in every other case the attribute stays None.
    if is_full_object_read and enable_checksum:
        self.rolling_checksum = google_crc32c.Checksum()
    else:
        self.rolling_checksum = None


class _ReadResumptionStrategy(_BaseResumptionStrategy):
Expand Down Expand Up @@ -90,6 +101,7 @@ def update_state_from_response(
)

download_states = state["download_states"]
checksum_enabled = state.get("enable_checksum", True)

for object_data_range in proto.object_data_ranges:
# Ignore empty ranges or ranges for IDs not in our state
Expand Down Expand Up @@ -125,7 +137,7 @@ def update_state_from_response(
checksummed_data = object_data_range.checksummed_data
data = checksummed_data.content

if checksummed_data.HasField("crc32c"):
if checksum_enabled and checksummed_data.HasField("crc32c"):
server_checksum = checksummed_data.crc32c
client_checksum = google_crc32c.value(data)
if server_checksum != client_checksum:
Expand All @@ -138,10 +150,14 @@ def update_state_from_response(
# Update State & Write Data
chunk_size = len(data)
read_state.user_buffer.write(data)

# Commit updates only after the write succeeds
if checksum_enabled and read_state.rolling_checksum is not None:
read_state.rolling_checksum.update(data)
read_state.bytes_written += chunk_size
read_state.next_expected_offset += chunk_size

# Final Byte Count Verification
# Final Byte Count & Full Object Checksum Verification
if object_data_range.range_end:
read_state.is_complete = True
if (
Expand All @@ -154,6 +170,26 @@ def update_state_from_response(
f"Expected {read_state.initial_length}, got {read_state.bytes_written}",
)

# Perform full-object checksum verification once the stream finishes.
if (
read_state.is_full_object_read
and checksum_enabled
and read_state.rolling_checksum is not None
):
full_obj_server_crc32c = state.get("full_obj_server_crc32c")
if full_obj_server_crc32c is not None:
# Use standard big-endian byte conversion to retrieve the rolling checksum value.
client_checksum = int.from_bytes(
read_state.rolling_checksum.digest(),
byteorder="big",
)
if client_checksum != full_obj_server_crc32c:
raise DataCorruption(
response,
f"Full object checksum mismatch for read_id {read_id}. "
f"Server authoritative crc32c: {full_obj_server_crc32c}, client calculated rolling: {client_checksum}.",
)

async def recover_state_on_failure(self, error: Exception, state: Any) -> None:
"""Handles BidiReadObjectRedirectedError for reads."""
routing_token, read_handle = _handle_redirect(error)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,48 @@ def test_initialization(self):
self.assertEqual(state.bytes_written, 0)
self.assertEqual(state.next_expected_offset, initial_offset)
self.assertFalse(state.is_complete)
self.assertFalse(state.is_full_object_read)
self.assertIsNone(state.rolling_checksum)

def test_initialization_with_full_object_read(self):
    """A state built with is_full_object_read=True gets a rolling checksum."""
    offset, length = 10, 100
    sink = io.BytesIO()

    full_state = _DownloadState(offset, length, sink, is_full_object_read=True)

    # Constructor arguments are stored as given.
    self.assertEqual(full_state.initial_offset, offset)
    self.assertEqual(full_state.initial_length, length)
    self.assertEqual(full_state.user_buffer, sink)

    # Progress counters start from a clean slate.
    self.assertEqual(full_state.bytes_written, 0)
    self.assertEqual(full_state.next_expected_offset, offset)
    self.assertFalse(full_state.is_complete)

    # Full-object mode with the default enable_checksum creates a checksum.
    self.assertTrue(full_state.is_full_object_read)
    self.assertIsNotNone(full_state.rolling_checksum)

def test_initialization_with_full_object_read_and_checksum_disabled(self):
    """With enable_checksum=False no rolling checksum is created, even for a full-object read."""
    offset, length = 10, 100
    sink = io.BytesIO()

    full_state = _DownloadState(
        offset,
        length,
        sink,
        is_full_object_read=True,
        enable_checksum=False,
    )

    # Constructor arguments are stored as given.
    self.assertEqual(full_state.initial_offset, offset)
    self.assertEqual(full_state.initial_length, length)
    self.assertEqual(full_state.user_buffer, sink)

    # Progress counters start from a clean slate.
    self.assertEqual(full_state.bytes_written, 0)
    self.assertEqual(full_state.next_expected_offset, offset)
    self.assertFalse(full_state.is_complete)

    # The full-object flag is kept, but the checksum object is not built.
    self.assertTrue(full_state.is_full_object_read)
    self.assertIsNone(full_state.rolling_checksum)


class TestReadResumptionStrategy(unittest.TestCase):
Expand All @@ -53,12 +95,24 @@ def setUp(self):

self.state = {"download_states": {}, "read_handle": None, "routing_token": None}

def _add_download(
    self,
    read_id,
    offset=0,
    length=100,
    buffer=None,
    is_full_object_read=False,
    enable_checksum=True,
):
    """Create a _DownloadState and register it under ``read_id`` in the test state.

    A fresh ``io.BytesIO`` is used when no buffer is supplied. The new
    download state is returned so tests can inspect it afterwards.
    """
    target = io.BytesIO() if buffer is None else buffer
    download = _DownloadState(
        initial_offset=offset,
        initial_length=length,
        user_buffer=target,
        is_full_object_read=is_full_object_read,
        enable_checksum=enable_checksum,
    )
    self.state["download_states"][read_id] = download
    return download
Expand Down Expand Up @@ -358,3 +412,61 @@ async def run():

# Token should remain unchanged
self.assertEqual(self.state["routing_token"], "existing-token")

def test_update_state_full_object_checksum_success(self):
    """A matching full-object CRC32C lets the download complete at range_end."""
    download = self._add_download(
        _READ_ID, offset=0, length=9, is_full_object_read=True
    )
    self.state["enable_checksum"] = True
    # Server-side checksum of the whole payload, which arrives in two chunks.
    self.state["full_obj_server_crc32c"] = google_crc32c.value(b"testdata1")

    first_chunk = self._create_response(b"test", _READ_ID, offset=0)
    self.strategy.update_state_from_response(first_chunk, self.state)

    last_chunk = self._create_response(
        b"data1", _READ_ID, offset=4, range_end=True
    )
    self.strategy.update_state_from_response(last_chunk, self.state)

    self.assertTrue(download.is_complete)
    self.assertEqual(download.bytes_written, 9)

def test_update_state_full_object_checksum_failure(self):
    """A wrong full-object CRC32C raises DataCorruption when range_end arrives."""
    self._add_download(_READ_ID, offset=0, length=9, is_full_object_read=True)
    self.state["enable_checksum"] = True
    # Deliberately wrong server checksum for the payload sent below.
    self.state["full_obj_server_crc32c"] = 111111

    first_chunk = self._create_response(b"test", _READ_ID, offset=0)
    self.strategy.update_state_from_response(first_chunk, self.state)

    last_chunk = self._create_response(
        b"data1", _READ_ID, offset=4, range_end=True
    )
    with self.assertRaisesRegex(DataCorruption, "Full object checksum mismatch"):
        self.strategy.update_state_from_response(last_chunk, self.state)

def test_update_state_checksum_mismatch_ignored_when_disabled(self):
    """With enable_checksum False, a chunk with a wrong CRC32C is accepted."""
    self._add_download(_READ_ID)
    self.state["enable_checksum"] = False
    bad_crc_chunk = self._create_response(
        b"data", _READ_ID, offset=0, crc=999999
    )

    # The test passes as long as this call does not raise DataCorruption.
    self.strategy.update_state_from_response(bad_crc_chunk, self.state)
Comment thread
chandra-siri marked this conversation as resolved.

def test_update_state_full_object_checksum_mismatch_ignored_when_disabled(self):
    """With enable_checksum False, a wrong full-object CRC32C is not reported."""
    self._add_download(
        _READ_ID,
        offset=0,
        length=9,
        is_full_object_read=True,
        enable_checksum=False,
    )
    self.state["enable_checksum"] = False
    # Deliberately wrong server checksum; it must be ignored.
    self.state["full_obj_server_crc32c"] = 111111

    first_chunk = self._create_response(b"test", _READ_ID, offset=0)
    self.strategy.update_state_from_response(first_chunk, self.state)

    last_chunk = self._create_response(
        b"data1", _READ_ID, offset=4, range_end=True
    )
    # The test passes as long as this call does not raise DataCorruption.
    self.strategy.update_state_from_response(last_chunk, self.state)
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ async def instantiate_read_obj_stream(mock_client, mock_cls_async_bidi_rpc, open
socket_like_rpc.open = AsyncMock()

recv_response = mock.MagicMock(spec=_storage_v2.BidiReadObjectResponse)
recv_response.metadata = mock.MagicMock(spec=_storage_v2.Object)
recv_response.metadata = mock.MagicMock()
recv_response.metadata.generation = _TEST_GENERATION_NUMBER
recv_response.metadata.size = _TEST_OBJECT_SIZE
recv_response.metadata.finalize_time.second = 30
recv_response.metadata.checksums.crc32c = 98765
recv_response.read_handle = _TEST_READ_HANDLE
socket_like_rpc.recv = AsyncMock(return_value=recv_response)

Expand Down Expand Up @@ -130,6 +132,8 @@ async def test_open(mock_client, mock_cls_async_bidi_rpc):
assert read_obj_stream.generation_number == _TEST_GENERATION_NUMBER
assert read_obj_stream.read_handle == _TEST_READ_HANDLE
assert read_obj_stream.persisted_size == _TEST_OBJECT_SIZE
assert read_obj_stream.is_finalized is True
assert read_obj_stream.full_obj_server_crc32c == 98765
assert read_obj_stream.is_stream_open


Expand Down Expand Up @@ -381,3 +385,36 @@ async def test_recv_updates_read_handle_on_refresh(

await stream.recv()
assert stream.read_handle == refreshed_handle


@mock.patch("google.cloud.storage.asyncio.async_read_object_stream.AsyncBidiRpc")
@mock.patch(
    "google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
)
@pytest.mark.asyncio
async def test_open_unfinalized_object_skips_checksum(
    mock_client, mock_cls_async_bidi_rpc
):
    """open() records neither finalization nor a server CRC32C when finalize_time.second is 0."""
    # Metadata as returned in the first response: a checksum is present,
    # but the finalize time marks the object as NOT finalized.
    object_metadata = mock.MagicMock()
    object_metadata.generation = _TEST_GENERATION_NUMBER
    object_metadata.size = _TEST_OBJECT_SIZE
    object_metadata.finalize_time.second = 0
    object_metadata.checksums.crc32c = 98765

    first_response = mock.MagicMock(spec=_storage_v2.BidiReadObjectResponse)
    first_response.metadata = object_metadata
    first_response.read_handle = _TEST_READ_HANDLE

    # Fake bidi RPC that hands back the response above on recv().
    bidi_rpc = AsyncMock()
    bidi_rpc.open = AsyncMock()
    bidi_rpc.recv = AsyncMock(return_value=first_response)
    mock_cls_async_bidi_rpc.return_value = bidi_rpc

    stream = _AsyncReadObjectStream(
        client=mock_client,
        bucket_name=_TEST_BUCKET_NAME,
        object_name=_TEST_OBJECT_NAME,
    )

    await stream.open()

    assert stream.is_finalized is False
    assert stream.full_obj_server_crc32c is None
32 changes: 32 additions & 0 deletions packages/google-cloud-storage/tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import pytest


@pytest.fixture(autouse=True)
def set_event_loop():
    """Make sure a current asyncio event loop exists around every test.

    If a loop is already running, the fixture yields without touching it.
    Otherwise it creates a new loop, installs it as the current loop for
    the duration of the test, then closes it and clears the current-loop
    setting.
    """
    # Keep the try body to the probe only: with the yield inside it, a
    # RuntimeError raised at the yield point would be mistaken for "no
    # running loop" and the generator would yield a second time.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
    else:
        yield
        return

    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        loop.close()
        asyncio.set_event_loop(None)
Loading