diff --git a/packages/google-cloud-spanner/google/cloud/spanner_v1/_async/streamed.py b/packages/google-cloud-spanner/google/cloud/spanner_v1/_async/streamed.py
index 3104274ced2c..154d4d394d39 100644
--- a/packages/google-cloud-spanner/google/cloud/spanner_v1/_async/streamed.py
+++ b/packages/google-cloud-spanner/google/cloud/spanner_v1/_async/streamed.py
@@ -129,16 +129,37 @@ def _merge_values(self, values):
         decoders = self._decoders
         width = len(self.fields)
         index = len(self._current_row)
-        for value in values:
-            if self._lazy_decode:
-                self._current_row.append(value)
-            else:
-                self._current_row.append(_parse_nullable(value, decoders[index]))
-            index += 1
-            if index == width:
-                self._rows.append(self._current_row)
-                self._current_row = []
-                index = 0
+        # Bind attribute lookups to locals for the per-value hot loop. A finished
+        # row is handed to ``rows`` and replaced, so ``append`` is rebound with it.
+        current_row = self._current_row
+        rows = self._rows
+
+        current_row_append = current_row.append
+        rows_append = rows.append
+
+        if self._lazy_decode:
+            for value in values:
+                current_row_append(value)
+                index += 1
+                if index == width:
+                    rows_append(current_row)
+                    current_row = []
+                    current_row_append = current_row.append
+                    index = 0
+        else:
+            for value in values:
+                if value.HasField("null_value"):
+                    current_row_append(None)
+                else:
+                    current_row_append(decoders[index](value))
+                index += 1
+                if index == width:
+                    rows_append(current_row)
+                    current_row = []
+                    current_row_append = current_row.append
+                    index = 0
+
+        self._current_row = current_row
 
     @CrossSync.convert
     async def _consume_next(self):
diff --git a/packages/google-cloud-spanner/google/cloud/spanner_v1/_helpers.py b/packages/google-cloud-spanner/google/cloud/spanner_v1/_helpers.py
index dfcf6721af82..433b8e016518 100644
--- a/packages/google-cloud-spanner/google/cloud/spanner_v1/_helpers.py
+++ b/packages/google-cloud-spanner/google/cloud/spanner_v1/_helpers.py
@@ -19,6 +19,7 @@
 import decimal
 import logging
 import math
+import operator
 import threading
 import time
 import uuid
@@ -26,7 +27,6 @@
 
 from google.api_core import datetime_helpers
 from google.api_core.exceptions import Aborted
-from google.cloud._helpers import _date_from_iso8601_date
 from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper
 from google.protobuf.message import DecodeError, Message
 from google.protobuf.struct_pb2 import ListValue, Value
@@ -465,6 +465,12 @@ def _parse_value_pb(value_pb, field_type, field_name, column_info=None):
     return _parse_nullable(value_pb, decoder)
 
 
+_date_fromisoformat = datetime.date.fromisoformat
+_Decimal = decimal.Decimal
+_json_from_str = JsonObject.from_str
+_uuid_UUID = uuid.UUID
+
+
 def _get_type_decoder(field_type, field_name, column_info=None):
     """Returns a function that converts a Value protobuf to cell data.
 
@@ -490,27 +496,35 @@ def _get_type_decoder(field_type, field_name, column_info=None):
 
     type_code = field_type.code
     if type_code == TypeCode.STRING:
-        return _parse_string
+        return operator.attrgetter("string_value")
     elif type_code == TypeCode.BYTES:
-        return _parse_bytes
+        return lambda value_pb: value_pb.string_value.encode("utf8")
     elif type_code == TypeCode.BOOL:
-        return _parse_bool
+        return operator.attrgetter("bool_value")
     elif type_code == TypeCode.INT64:
-        return _parse_int64
+        return lambda value_pb: int(value_pb.string_value)
     elif type_code == TypeCode.FLOAT64:
-        return _parse_float
+        return (
+            lambda value_pb: float(value_pb.string_value)
+            if value_pb.HasField("string_value")
+            else value_pb.number_value
+        )
     elif type_code == TypeCode.FLOAT32:
-        return _parse_float
+        return (
+            lambda value_pb: float(value_pb.string_value)
+            if value_pb.HasField("string_value")
+            else value_pb.number_value
+        )
     elif type_code == TypeCode.DATE:
-        return _parse_date
+        return lambda value_pb: _date_fromisoformat(value_pb.string_value)
     elif type_code == TypeCode.TIMESTAMP:
         return _parse_timestamp
     elif type_code == TypeCode.NUMERIC:
-        return _parse_numeric
+        return lambda value_pb: _Decimal(value_pb.string_value)
     elif type_code == TypeCode.JSON:
-        return _parse_json
+        return lambda value_pb: _json_from_str(value_pb.string_value)
     elif type_code == TypeCode.UUID:
-        return _parse_uuid
+        return lambda value_pb: _uuid_UUID(value_pb.string_value)
     elif type_code == TypeCode.PROTO:
         return lambda value_pb: _parse_proto(value_pb, column_info, field_name)
     elif type_code == TypeCode.ENUM:
@@ -553,48 +567,91 @@ def _parse_list_value_pbs(rows, row_type):
     return result
 
 
-def _parse_string(value_pb) -> str:
-    return value_pb.string_value
-
-
-def _parse_bytes(value_pb):
-    return value_pb.string_value.encode("utf8")
-
-
-def _parse_bool(value_pb) -> bool:
-    return value_pb.bool_value
-
-
-def _parse_int64(value_pb) -> int:
-    return int(value_pb.string_value)
-
-
-def _parse_float(value_pb) -> float:
-    if value_pb.HasField("string_value"):
-        return float(value_pb.string_value)
-    else:
-        return value_pb.number_value
-
-
-def _parse_date(value_pb):
-    return _date_from_iso8601_date(value_pb.string_value)
+# _POWERS_OF_10[n] == 10**n; pads a short fraction of a second up to nanoseconds.
+_POWERS_OF_10 = (
+    1,
+    10,
+    100,
+    1000,
+    10000,
+    100000,
+    1000000,
+    10000000,
+    100000000,
+    1000000000,
+)
 
 
 def _parse_timestamp(value_pb):
-    DatetimeWithNanoseconds = datetime_helpers.DatetimeWithNanoseconds
-    return DatetimeWithNanoseconds.from_rfc3339(value_pb.string_value)
-
-
-def _parse_numeric(value_pb):
-    return decimal.Decimal(value_pb.string_value)
-
-
-def _parse_json(value_pb):
-    return JsonObject.from_str(value_pb.string_value)
-
-
-def _parse_uuid(value_pb):
-    return uuid.UUID(value_pb.string_value)
+    """Convert an RFC 3339 timestamp ``Value`` into a UTC datetime.
+
+    Accepts ``YYYY-MM-DDTHH:MM:SS[.f{1,9}]`` followed by ``Z`` or a
+    ``+HH:MM`` / ``-HH:MM`` offset; offsets are normalized to UTC.
+
+    :type value_pb: :class:`~google.protobuf.struct_pb2.Value`
+    :param value_pb: protobuf whose ``string_value`` holds the timestamp.
+
+    :rtype: :class:`~google.api_core.datetime_helpers.DatetimeWithNanoseconds`
+    :returns: the parsed timestamp with nanosecond precision.
+    :raises ValueError: if the string does not match the expected pattern.
+    """
+    val = value_pb.string_value
+    try:
+        if len(val) < 20 or val[10] != "T":
+            raise ValueError()
+        no_fraction = val[:19]
+        bare = datetime.datetime.fromisoformat(no_fraction)
+        if val[19] == ".":
+            if val.endswith("Z"):
+                offset = "Z"
+                fraction = val[20:-1]
+            elif val[-6] in ("+", "-"):
+                offset = val[-6:]
+                fraction = val[20:-6]
+            else:
+                raise ValueError()
+            if not fraction or len(fraction) > 9 or not fraction.isdigit():
+                raise ValueError()
+            scale = 9 - len(fraction)
+            nanos = int(fraction) * _POWERS_OF_10[scale]
+        else:
+            nanos = 0
+            if val.endswith("Z"):
+                offset = "Z"
+            elif val[-6] in ("+", "-"):
+                offset = val[-6:]
+            else:
+                raise ValueError()
+            # With no fraction the offset must directly follow the seconds;
+            # otherwise input such as "...:47junkZ" would be accepted.
+            if val[19:] != offset:
+                raise ValueError()
+
+        if offset != "Z":
+            sign = offset[0]
+            # int() tolerates signs and whitespace, so require plain digits.
+            if offset[3] != ":" or not (offset[1:3] + offset[4:6]).isdigit():
+                raise ValueError()
+            hours = int(offset[1:3])
+            minutes = int(offset[4:6])
+            delta = datetime.timedelta(hours=hours, minutes=minutes)
+            if sign == "-":
+                delta = -delta
+            tzinfo = datetime.timezone(delta)
+            bare = bare.replace(tzinfo=tzinfo).astimezone(datetime.timezone.utc)
+
+        return datetime_helpers.DatetimeWithNanoseconds(
+            bare.year,
+            bare.month,
+            bare.day,
+            bare.hour,
+            bare.minute,
+            bare.second,
+            nanosecond=nanos,
+            tzinfo=datetime.timezone.utc,
+        )
+    except (IndexError, OverflowError, ValueError) as e:
+        raise ValueError("Timestamp: {} does not match pattern".format(val)) from e
 
 
 def _parse_proto(value_pb, column_info, field_name):
diff --git a/packages/google-cloud-spanner/google/cloud/spanner_v1/data_types.py b/packages/google-cloud-spanner/google/cloud/spanner_v1/data_types.py
index 3c3a7f6bfe32..59a2268e98a7 100644
--- a/packages/google-cloud-spanner/google/cloud/spanner_v1/data_types.py
+++ b/packages/google-cloud-spanner/google/cloud/spanner_v1/data_types.py
@@ -99,6 +99,11 @@ def serialize(self):
         return json.dumps(self, sort_keys=True, separators=(",", ":"))
 
 
+_INTERVAL_PATTERN = re.compile(
+    r"^P(-?\d+Y)?(-?\d+M)?(-?\d+D)?(T(-?\d+H)?(-?\d+M)?(-?((\d+([.,]\d{1,9})?)|([.,]\d{1,9}))S)?)?$"
+)
+
+
 @dataclass
 class Interval:
     """Represents a Spanner INTERVAL type.
@@ -187,8 +192,7 @@ def __str__(self) -> str:
     @classmethod
     def from_str(cls, s: str) -> "Interval":
         """Parse an ISO8601 duration format string into an Interval."""
-        pattern = r"^P(-?\d+Y)?(-?\d+M)?(-?\d+D)?(T(-?\d+H)?(-?\d+M)?(-?((\d+([.,]\d{1,9})?)|([.,]\d{1,9}))S)?)?$"
-        match = re.match(pattern, s)
+        match = _INTERVAL_PATTERN.match(s)
         if not match or len(s) == 1:
             raise ValueError(f"Invalid interval format: {s}")
 
diff --git a/packages/google-cloud-spanner/google/cloud/spanner_v1/streamed.py b/packages/google-cloud-spanner/google/cloud/spanner_v1/streamed.py
index 59d8d8b746d5..10cadbab0a49 100644
--- a/packages/google-cloud-spanner/google/cloud/spanner_v1/streamed.py
+++ b/packages/google-cloud-spanner/google/cloud/spanner_v1/streamed.py
@@ -35,8 +35,7 @@ class StreamedResultSet(object):
         instances.
 
     :type source: :class:`~google.cloud.spanner_v1.snapshot.Snapshot`
-    :param source: Deprecated. Snapshot from which the result set was fetched.
-    """
+    :param source: Deprecated. Snapshot from which the result set was fetched."""
 
     def __init__(
         self,
@@ -117,16 +116,32 @@ def _merge_values(self, values):
         decoders = self._decoders
         width = len(self.fields)
         index = len(self._current_row)
-        for value in values:
-            if self._lazy_decode:
-                self._current_row.append(value)
-            else:
-                self._current_row.append(_parse_nullable(value, decoders[index]))
-            index += 1
-            if index == width:
-                self._rows.append(self._current_row)
-                self._current_row = []
-                index = 0
+        current_row = self._current_row
+        rows = self._rows
+        current_row_append = current_row.append
+        rows_append = rows.append
+        if self._lazy_decode:
+            for value in values:
+                current_row_append(value)
+                index += 1
+                if index == width:
+                    rows_append(current_row)
+                    current_row = []
+                    current_row_append = current_row.append
+                    index = 0
+        else:
+            for value in values:
+                if value.HasField("null_value"):
+                    current_row_append(None)
+                else:
+                    current_row_append(decoders[index](value))
+                index += 1
+                if index == width:
+                    rows_append(current_row)
+                    current_row = []
+                    current_row_append = current_row.append
+                    index = 0
+        self._current_row = current_row
 
     def _consume_next(self):
         """Consume the next partial result set from the stream.
diff --git a/packages/google-cloud-spanner/tests/unit/test__helpers.py b/packages/google-cloud-spanner/tests/unit/test__helpers.py
index b81e745d418f..45493f6c1a88 100644
--- a/packages/google-cloud-spanner/tests/unit/test__helpers.py
+++ b/packages/google-cloud-spanner/tests/unit/test__helpers.py
@@ -653,6 +653,96 @@ def test_w_timestamp_w_nanos(self):
         self.assertIsInstance(parsed, datetime_helpers.DatetimeWithNanoseconds)
         self.assertEqual(parsed, value)
 
+    def test_w_timestamp_w_offset(self):
+        from google.api_core import datetime_helpers
+        from google.protobuf.struct_pb2 import Value
+
+        from google.cloud.spanner_v1 import Type, TypeCode
+
+        value_pb = Value(string_value="2016-12-20T12:13:47.123456789+01:00")
+        field_type = Type(code=TypeCode.TIMESTAMP)
+        field_name = "timestamp_column"
+
+        expected = datetime_helpers.DatetimeWithNanoseconds(
+            2016, 12, 20, 11, 13, 47, nanosecond=123456789, tzinfo=timezone.utc
+        )
+
+        parsed = self._callFUT(value_pb, field_type, field_name)
+        self.assertIsInstance(parsed, datetime_helpers.DatetimeWithNanoseconds)
+        self.assertEqual(parsed, expected)
+
+        value_pb_neg = Value(string_value="2016-12-20T12:13:47.123456789-05:00")
+        expected_neg = datetime_helpers.DatetimeWithNanoseconds(
+            2016, 12, 20, 17, 13, 47, nanosecond=123456789, tzinfo=timezone.utc
+        )
+        parsed_neg = self._callFUT(value_pb_neg, field_type, field_name)
+        self.assertEqual(parsed_neg, expected_neg)
+
+    def test_w_timestamp_various_formats(self):
+        from google.api_core import datetime_helpers
+        from google.protobuf.struct_pb2 import Value
+
+        from google.cloud.spanner_v1 import Type, TypeCode
+
+        field_type = Type(code=TypeCode.TIMESTAMP)
+        field_name = "timestamp_column"
+
+        # 1. No seconds fraction, UTC (Z)
+        value_pb = Value(string_value="2016-12-20T21:13:47Z")
+        expected = datetime_helpers.DatetimeWithNanoseconds(
+            2016, 12, 20, 21, 13, 47, nanosecond=0, tzinfo=timezone.utc
+        )
+        parsed = self._callFUT(value_pb, field_type, field_name)
+        self.assertEqual(parsed, expected)
+
+        # 2. Single-digit fraction (tenths of a second), UTC (Z)
+        value_pb = Value(string_value="2016-12-20T21:13:47.1Z")
+        expected = datetime_helpers.DatetimeWithNanoseconds(
+            2016, 12, 20, 21, 13, 47, nanosecond=100000000, tzinfo=timezone.utc
+        )
+        parsed = self._callFUT(value_pb, field_type, field_name)
+        self.assertEqual(parsed, expected)
+
+        # 3. Milliseconds (3 digits fraction), UTC (Z)
+        value_pb = Value(string_value="2016-12-20T21:13:47.123Z")
+        expected = datetime_helpers.DatetimeWithNanoseconds(
+            2016, 12, 20, 21, 13, 47, nanosecond=123000000, tzinfo=timezone.utc
+        )
+        parsed = self._callFUT(value_pb, field_type, field_name)
+        self.assertEqual(parsed, expected)
+
+        # 4. Microseconds (6 digits fraction), UTC (Z)
+        value_pb = Value(string_value="2016-12-20T21:13:47.123456Z")
+        expected = datetime_helpers.DatetimeWithNanoseconds(
+            2016, 12, 20, 21, 13, 47, nanosecond=123456000, tzinfo=timezone.utc
+        )
+        parsed = self._callFUT(value_pb, field_type, field_name)
+        self.assertEqual(parsed, expected)
+
+    def test_w_timestamp_invalid_formats(self):
+        from google.protobuf.struct_pb2 import Value
+
+        from google.cloud.spanner_v1 import Type, TypeCode
+
+        field_type = Type(code=TypeCode.TIMESTAMP)
+        field_name = "timestamp_column"
+
+        invalid_strings = [
+            "2016-12-20T21:13:47",  # Missing timezone offset
+            "2016-12-20 21:13:47Z",  # Space instead of 'T' separator
+            "2016-12-20T21:13:47+0100",  # Missing colon in offset
+            "2016-12-20T21:13:47.1234567890Z",  # Too many sub-seconds digits (10 digits)
+            "2016-12-20T21:13:4Z",  # Single digit second
+            "2016-12-20T21:1:47Z",  # Single digit minute
+            "2016-12-20T2:13:47Z",  # Single digit hour
+            "2016-12-20T21:13:47+1:00",  # Single digit hour in offset
+        ]
+
+        for invalid_string in invalid_strings:
+            value_pb = Value(string_value=invalid_string)
+            with self.assertRaises(ValueError):
+                self._callFUT(value_pb, field_type, field_name)
+
     def test_w_array_empty(self):
         from google.protobuf.struct_pb2 import ListValue, Value
 