Merge branch 'main' into trace-annotation-SessionPools
odeke-em authored Dec 3, 2024
2 parents 5bf38b7 + ccae6e0 commit 498a70a
Showing 4 changed files with 260 additions and 79 deletions.
178 changes: 131 additions & 47 deletions google/cloud/spanner_v1/_helpers.py
@@ -266,66 +266,69 @@ def _parse_value_pb(value_pb, field_type, field_name, column_info=None):
     :returns: value extracted from value_pb
     :raises ValueError: if unknown type is passed
     """
+    decoder = _get_type_decoder(field_type, field_name, column_info)
+    return _parse_nullable(value_pb, decoder)
+
+
+def _get_type_decoder(field_type, field_name, column_info=None):
+    """Returns a function that converts a Value protobuf to cell data.
+
+    :type field_type: :class:`~google.cloud.spanner_v1.types.Type`
+    :param field_type: type code for the value
+
+    :type field_name: str
+    :param field_name: column name
+
+    :type column_info: dict
+    :param column_info: (Optional) dict of column name and column information.
+        An object where column names as keys and custom objects as corresponding
+        values for deserialization. It's specifically useful for data types like
+        protobuf where deserialization logic is on user-specific code. When provided,
+        the custom object enables deserialization of backend-received column data.
+        If not provided, data remains serialized as bytes for Proto Messages and
+        integer for Proto Enums.
+
+    :rtype: a function that takes a single protobuf value as an input argument
+    :returns: a function that can be used to extract a value from a protobuf value
+
+    :raises ValueError: if unknown type is passed
+    """
+
     type_code = field_type.code
-    if value_pb.HasField("null_value"):
-        return None
     if type_code == TypeCode.STRING:
-        return value_pb.string_value
+        return _parse_string
     elif type_code == TypeCode.BYTES:
-        return value_pb.string_value.encode("utf8")
+        return _parse_bytes
     elif type_code == TypeCode.BOOL:
-        return value_pb.bool_value
+        return _parse_bool
     elif type_code == TypeCode.INT64:
-        return int(value_pb.string_value)
+        return _parse_int64
     elif type_code == TypeCode.FLOAT64:
-        if value_pb.HasField("string_value"):
-            return float(value_pb.string_value)
-        else:
-            return value_pb.number_value
+        return _parse_float
     elif type_code == TypeCode.FLOAT32:
-        if value_pb.HasField("string_value"):
-            return float(value_pb.string_value)
-        else:
-            return value_pb.number_value
+        return _parse_float
     elif type_code == TypeCode.DATE:
-        return _date_from_iso8601_date(value_pb.string_value)
+        return _parse_date
     elif type_code == TypeCode.TIMESTAMP:
-        DatetimeWithNanoseconds = datetime_helpers.DatetimeWithNanoseconds
-        return DatetimeWithNanoseconds.from_rfc3339(value_pb.string_value)
-    elif type_code == TypeCode.ARRAY:
-        return [
-            _parse_value_pb(
-                item_pb, field_type.array_element_type, field_name, column_info
-            )
-            for item_pb in value_pb.list_value.values
-        ]
-    elif type_code == TypeCode.STRUCT:
-        return [
-            _parse_value_pb(
-                item_pb, field_type.struct_type.fields[i].type_, field_name, column_info
-            )
-            for (i, item_pb) in enumerate(value_pb.list_value.values)
-        ]
+        return _parse_timestamp
     elif type_code == TypeCode.NUMERIC:
-        return decimal.Decimal(value_pb.string_value)
+        return _parse_numeric
     elif type_code == TypeCode.JSON:
-        return JsonObject.from_str(value_pb.string_value)
+        return _parse_json
     elif type_code == TypeCode.PROTO:
-        bytes_value = base64.b64decode(value_pb.string_value)
-        if column_info is not None and column_info.get(field_name) is not None:
-            default_proto_message = column_info.get(field_name)
-            if isinstance(default_proto_message, Message):
-                proto_message = type(default_proto_message)()
-                proto_message.ParseFromString(bytes_value)
-                return proto_message
-        return bytes_value
+        return lambda value_pb: _parse_proto(value_pb, column_info, field_name)
     elif type_code == TypeCode.ENUM:
-        int_value = int(value_pb.string_value)
-        if column_info is not None and column_info.get(field_name) is not None:
-            proto_enum = column_info.get(field_name)
-            if isinstance(proto_enum, EnumTypeWrapper):
-                return proto_enum.Name(int_value)
-        return int_value
+        return lambda value_pb: _parse_proto_enum(value_pb, column_info, field_name)
+    elif type_code == TypeCode.ARRAY:
+        element_decoder = _get_type_decoder(
+            field_type.array_element_type, field_name, column_info
+        )
+        return lambda value_pb: _parse_array(value_pb, element_decoder)
+    elif type_code == TypeCode.STRUCT:
+        element_decoders = [
+            _get_type_decoder(item_field.type_, field_name, column_info)
+            for item_field in field_type.struct_type.fields
+        ]
+        return lambda value_pb: _parse_struct(value_pb, element_decoders)
     else:
         raise ValueError("Unknown type: %s" % (field_type,))

@@ -351,6 +354,87 @@ def _parse_list_value_pbs(rows, row_type):
     return result


+def _parse_string(value_pb) -> str:
+    return value_pb.string_value
+
+
+def _parse_bytes(value_pb):
+    return value_pb.string_value.encode("utf8")
+
+
+def _parse_bool(value_pb) -> bool:
+    return value_pb.bool_value
+
+
+def _parse_int64(value_pb) -> int:
+    return int(value_pb.string_value)
+
+
+def _parse_float(value_pb) -> float:
+    if value_pb.HasField("string_value"):
+        return float(value_pb.string_value)
+    else:
+        return value_pb.number_value
+
+
+def _parse_date(value_pb):
+    return _date_from_iso8601_date(value_pb.string_value)
+
+
+def _parse_timestamp(value_pb):
+    DatetimeWithNanoseconds = datetime_helpers.DatetimeWithNanoseconds
+    return DatetimeWithNanoseconds.from_rfc3339(value_pb.string_value)
+
+
+def _parse_numeric(value_pb):
+    return decimal.Decimal(value_pb.string_value)
+
+
+def _parse_json(value_pb):
+    return JsonObject.from_str(value_pb.string_value)
+
+
+def _parse_proto(value_pb, column_info, field_name):
+    bytes_value = base64.b64decode(value_pb.string_value)
+    if column_info is not None and column_info.get(field_name) is not None:
+        default_proto_message = column_info.get(field_name)
+        if isinstance(default_proto_message, Message):
+            proto_message = type(default_proto_message)()
+            proto_message.ParseFromString(bytes_value)
+            return proto_message
+    return bytes_value
+
+
+def _parse_proto_enum(value_pb, column_info, field_name):
+    int_value = int(value_pb.string_value)
+    if column_info is not None and column_info.get(field_name) is not None:
+        proto_enum = column_info.get(field_name)
+        if isinstance(proto_enum, EnumTypeWrapper):
+            return proto_enum.Name(int_value)
+    return int_value
+
+
+def _parse_array(value_pb, element_decoder) -> list:
+    return [
+        _parse_nullable(item_pb, element_decoder)
+        for item_pb in value_pb.list_value.values
+    ]
+
+
+def _parse_struct(value_pb, element_decoders):
+    return [
+        _parse_nullable(item_pb, element_decoders[i])
+        for (i, item_pb) in enumerate(value_pb.list_value.values)
+    ]
+
+
+def _parse_nullable(value_pb, decoder):
+    if value_pb.HasField("null_value"):
+        return None
+    else:
+        return decoder(value_pb)
+
+
 class _SessionWrapper(object):
     """Base class for objects wrapping a session.
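Taken together, the _helpers.py change moves type dispatch out of the per-value path: _get_type_decoder resolves a decoder once per column, and _parse_nullable applies it to each value, centralizing NULL handling. A minimal, self-contained sketch of that pattern follows; the type codes, decoder table, and helper names in it are illustrative stand-ins, not the Spanner client's actual API.

# Sketch of the decode-once pattern: resolve a decoder per column up
# front, then reuse it for every row. Names here are hypothetical.
from decimal import Decimal

# Hypothetical column type codes mapped to decoders, mirroring the
# shape of _get_type_decoder: dispatch happens once, not per cell.
_DECODERS = {
    "INT64": int,
    "FLOAT64": float,
    "NUMERIC": Decimal,
    "STRING": str,
}

def get_decoder(type_code):
    # Resolve a decoder once per column, like _get_type_decoder.
    try:
        return _DECODERS[type_code]
    except KeyError:
        raise ValueError("Unknown type: %s" % (type_code,))

def parse_nullable(value, decoder):
    # Centralized null handling, like _parse_nullable.
    return None if value is None else decoder(value)

# Decoders are built once from the column types ...
column_types = ["INT64", "NUMERIC", "STRING"]
decoders = [get_decoder(t) for t in column_types]

# ... and reused for every row, avoiding repeated type dispatch.
rows = [["1", "9.99", "a"], ["2", None, "b"]]
decoded = [
    [parse_nullable(v, d) for v, d in zip(row, decoders)] for row in rows
]
print(decoded)  # [[1, Decimal('9.99'), 'a'], [2, None, 'b']]

Hoisting dispatch out of the per-cell path is presumably also what makes the lazy_decode flag added in snapshot.py below cheap to support: decoding can be deferred per row or per column without redoing type resolution.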
54 changes: 48 additions & 6 deletions google/cloud/spanner_v1/snapshot.py
@@ -192,6 +192,7 @@ def read(
         retry=gapic_v1.method.DEFAULT,
         timeout=gapic_v1.method.DEFAULT,
         column_info=None,
+        lazy_decode=False,
     ):
         """Perform a ``StreamingRead`` API request for rows in a table.
@@ -255,6 +256,18 @@ def read(
             If not provided, data remains serialized as bytes for Proto Messages and
             integer for Proto Enums.

+        :type lazy_decode: bool
+        :param lazy_decode:
+            (Optional) If this argument is set to ``true``, the iterator
+            returns the underlying protobuf values instead of decoded Python
+            objects. This reduces the time that is needed to iterate through
+            large result sets. The application is responsible for decoding
+            the data that is needed. The returned row iterator contains two
+            functions that can be used for this. ``iterator.decode_row(row)``
+            decodes all the columns in the given row to an array of Python
+            objects. ``iterator.decode_column(row, column_index)`` decodes one
+            specific column in the given row.
+
         :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
         :returns: a result set instance which can be used to consume rows.
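A usage sketch for the new flag on read: `snapshot`, the `Singers` table, and its columns are assumptions made for illustration, while `lazy_decode`, `decode_row`, and `decode_column` are as documented in the docstring above.

# Illustrative sketch: assumes an existing `snapshot` and a `Singers` table.
from google.cloud.spanner_v1 import KeySet

results = snapshot.read(
    table="Singers",
    columns=("SingerId", "FirstName"),
    keyset=KeySet(all_=True),
    lazy_decode=True,  # iterate raw protobuf values instead of Python objects
)
for row in results:
    # Decode a full row only when the application actually needs it.
    singer_id, first_name = results.decode_row(row)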
@@ -330,10 +343,15 @@ def read(
             self._read_request_count += 1
             if self._multi_use:
                 return StreamedResultSet(
-                    iterator, source=self, column_info=column_info
+                    iterator,
+                    source=self,
+                    column_info=column_info,
+                    lazy_decode=lazy_decode,
                 )
             else:
-                return StreamedResultSet(iterator, column_info=column_info)
+                return StreamedResultSet(
+                    iterator, column_info=column_info, lazy_decode=lazy_decode
+                )
         else:
             iterator = _restart_on_unavailable(
                 restart,
@@ -348,9 +366,13 @@ def read(
             self._read_request_count += 1

             if self._multi_use:
-                return StreamedResultSet(iterator, source=self, column_info=column_info)
+                return StreamedResultSet(
+                    iterator, source=self, column_info=column_info, lazy_decode=lazy_decode
+                )
             else:
-                return StreamedResultSet(iterator, column_info=column_info)
+                return StreamedResultSet(
+                    iterator, column_info=column_info, lazy_decode=lazy_decode
+                )

     def execute_sql(
         self,
@@ -366,6 +388,7 @@ def execute_sql(
         data_boost_enabled=False,
         directed_read_options=None,
         column_info=None,
+        lazy_decode=False,
     ):
         """Perform an ``ExecuteStreamingSql`` API request.
@@ -438,6 +461,18 @@ def execute_sql(
             If not provided, data remains serialized as bytes for Proto Messages and
             integer for Proto Enums.

+        :type lazy_decode: bool
+        :param lazy_decode:
+            (Optional) If this argument is set to ``true``, the iterator
+            returns the underlying protobuf values instead of decoded Python
+            objects. This reduces the time that is needed to iterate through
+            large result sets. The application is responsible for decoding
+            the data that is needed. The returned row iterator contains two
+            functions that can be used for this. ``iterator.decode_row(row)``
+            decodes all the columns in the given row to an array of Python
+            objects. ``iterator.decode_column(row, column_index)`` decodes one
+            specific column in the given row.
+
         :raises ValueError:
             for reuse of single-use snapshots, or if a transaction ID is
             already pending for multiple-use snapshots.
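The same flag on execute_sql, sketched with decode_column for selective decoding; `snapshot`, the query, and the schema are again assumptions for illustration.

# Illustrative sketch: decode only the first column of each row,
# leaving AlbumTitle as its undecoded protobuf value.
results = snapshot.execute_sql(
    "SELECT SingerId, AlbumTitle FROM Albums",
    lazy_decode=True,
)
for row in results:
    singer_id = results.decode_column(row, 0)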
@@ -517,6 +552,7 @@ def execute_sql(
                 trace_attributes,
                 column_info,
                 observability_options,
+                lazy_decode=lazy_decode,
             )
         else:
             return self._get_streamed_result_set(
@@ -525,6 +561,7 @@ def execute_sql(
                 trace_attributes,
                 column_info,
                 observability_options,
+                lazy_decode=lazy_decode,
             )

     def _get_streamed_result_set(
@@ -534,6 +571,7 @@ def _get_streamed_result_set(
         trace_attributes,
         column_info,
         observability_options=None,
+        lazy_decode=False,
     ):
         iterator = _restart_on_unavailable(
             restart,
@@ -548,9 +586,13 @@ def _get_streamed_result_set(
         self._execute_sql_count += 1

         if self._multi_use:
-            return StreamedResultSet(iterator, source=self, column_info=column_info)
+            return StreamedResultSet(
+                iterator, source=self, column_info=column_info, lazy_decode=lazy_decode
+            )
         else:
-            return StreamedResultSet(iterator, column_info=column_info)
+            return StreamedResultSet(
+                iterator, column_info=column_info, lazy_decode=lazy_decode
+            )

     def partition_read(
         self,
