Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from google.api_core.exceptions import Aborted
from google.cloud._helpers import _date_from_iso8601_date
from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper
from google.protobuf.message import Message
from google.protobuf.message import DecodeError, Message
from google.protobuf.struct_pb2 import ListValue, Value
from google.rpc.error_details_pb2 import RetryInfo

Expand Down Expand Up @@ -76,7 +76,7 @@

GOOGLE_CLOUD_REGION_GLOBAL = "global"

log = logging.getLogger(__name__)
_LOGGER = logging.getLogger(__name__)

_cloud_region: str = None

Expand Down Expand Up @@ -122,7 +122,7 @@ def _get_cloud_region() -> str:
else:
_cloud_region = GOOGLE_CLOUD_REGION_GLOBAL
except Exception as e:
log.warning(
_LOGGER.warning(
"Failed to detect GCP resource location for Spanner metrics, defaulting to 'global'. Error: %s",
e,
)
Expand Down Expand Up @@ -603,8 +603,14 @@ def _parse_proto(value_pb, column_info, field_name):
default_proto_message = column_info.get(field_name)
if isinstance(default_proto_message, Message):
proto_message = type(default_proto_message)()
proto_message.ParseFromString(bytes_value)
return proto_message
try:
proto_message.ParseFromString(bytes_value)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a way to enforce nesting level?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The standard Python Protobuf API (ParseFromString) does not expose a parameter to enforce a custom recursion or nesting limit during parsing.
Although, we did consider using sys.setrecursionlimit() to force a lower limit (e.g. 100) right before parsing but had to reject it because sys.setrecursionlimit is process-wide. In a multi-threaded environment, lowering the limit for this one call could cause arbitrary valid code run by other threads to crash.

I couldn't find any other thread-safe way to enforce a custom limit during parsing in Python Protobuf, the most robust solution is to catch the errors the parser naturally throws (DecodeError and RecursionError) and fall back.

In Java, we can explicitly enforce and set a custom nesting level while parsing which is not possible in Python.

return proto_message
except (DecodeError, RecursionError):
_LOGGER.warning(
"Field could not be parsed as Proto due to excessive nesting/corruption. Returning raw bytes."
)
return bytes_value
return bytes_value


Expand Down
53 changes: 53 additions & 0 deletions packages/google-cloud-spanner/tests/unit/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,59 @@ def test_w_proto_message(self):
self._callFUT(value_pb, field_type, field_name, column_info), VALUE
)

def test_w_proto_message_decode_error(self):
import base64
from unittest import mock

from google.protobuf.message import DecodeError
from google.protobuf.struct_pb2 import Value

from google.cloud.spanner_v1 import Type, TypeCode

from .testdata import singer_pb2

VALUE = singer_pb2.SingerInfo(singer_id=1, nationality="Canadian")
field_type = Type(code=TypeCode.PROTO)
field_name = "proto_message_column"
raw_bytes = VALUE.SerializeToString()
value_pb = Value(string_value=base64.b64encode(raw_bytes).decode("utf-8"))
column_info = {"proto_message_column": singer_pb2.SingerInfo()}

# Mock ParseFromString to raise DecodeError
with mock.patch.object(
singer_pb2.SingerInfo,
"ParseFromString",
side_effect=DecodeError("Mock Decode Error"),
):
result = self._callFUT(value_pb, field_type, field_name, column_info)
# Should return raw bytes
self.assertEqual(result, raw_bytes)

def test_w_proto_message_recursion_error(self):
import base64
from unittest import mock

from google.protobuf.struct_pb2 import Value

from google.cloud.spanner_v1 import Type, TypeCode

from .testdata import singer_pb2

VALUE = singer_pb2.SingerInfo(singer_id=1, nationality="Canadian")
field_type = Type(code=TypeCode.PROTO)
field_name = "proto_message_column"
raw_bytes = VALUE.SerializeToString()
value_pb = Value(string_value=base64.b64encode(raw_bytes).decode("utf-8"))
column_info = {"proto_message_column": singer_pb2.SingerInfo()}

with mock.patch.object(
singer_pb2.SingerInfo,
"ParseFromString",
side_effect=RecursionError("Mock Recursion Error"),
):
result = self._callFUT(value_pb, field_type, field_name, column_info)
self.assertEqual(result, raw_bytes)

def test_w_proto_enum(self):
from google.protobuf.struct_pb2 import Value

Expand Down
Loading